forked from mindspore-Ecosystem/mindspore
commit
7b8e0b72bb
|
@ -14,4 +14,4 @@ mindspore.dataset.vision.CenterCrop
|
|||
异常:
|
||||
- **TypeError** - 如果 `size` 不是int或sequence类型。
|
||||
- **ValueError** - 如果 `size` 小于或等于 0。
|
||||
- **RuntimeError** - 如果输入图像的shape不是 <H, W> 或 <H, W, C>。
|
||||
- **RuntimeError** - 如果输入图像的shape不是 <H, W> 或 <..., H, W, C>。
|
||||
|
|
|
@ -14,9 +14,12 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/kernels/image/center_crop_op.h"
|
||||
|
||||
#include <string>
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
#include "minddata/dataset/kernels/data/data_utils.h"
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||
#else
|
||||
|
@ -29,20 +32,7 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
const int32_t CenterCropOp::kDefWidth = 0;
|
||||
|
||||
Status CenterCropOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
std::string err_msg;
|
||||
std::string err_head = "CenterCrop: ";
|
||||
dsize_t rank = input->shape().Rank();
|
||||
err_msg +=
|
||||
(rank < 2 || rank > 3) ? "image shape is not <H,W,C> or <H,W>, but got rank: " + std::to_string(rank) + "\t" : "";
|
||||
err_msg += (crop_het_ <= 0 || crop_wid_ <= 0)
|
||||
? "crop size needs to be positive integers, but got crop height:" + std::to_string(crop_het_) +
|
||||
", crop width: " + std::to_string(crop_wid_) + "\t"
|
||||
: "";
|
||||
|
||||
CHECK_FAIL_RETURN_SYNTAX_ERROR(err_msg.length() == 0, err_head + err_msg);
|
||||
|
||||
Status CenterCropOp::CenterCropImg(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
int32_t top = crop_het_ - input->shape()[0]; // number of pixels to pad (top and bottom)
|
||||
int32_t left = crop_wid_ - input->shape()[1];
|
||||
std::shared_ptr<Tensor> pad_image;
|
||||
|
@ -69,26 +59,85 @@ Status CenterCropOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_p
|
|||
crop_het_);
|
||||
}
|
||||
|
||||
Status CenterCropOp::ConstructShape(const TensorShape &in_shape, std::shared_ptr<TensorShape> *out_shape) {
|
||||
auto in_shape_vec = in_shape.AsVector();
|
||||
const int h_index = -3, w_index = -2;
|
||||
in_shape_vec[in_shape_vec.size() + h_index] = crop_het_;
|
||||
in_shape_vec[in_shape_vec.size() + w_index] = crop_wid_;
|
||||
|
||||
*out_shape = std::make_shared<TensorShape>(in_shape_vec);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CenterCropOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
std::string err_msg;
|
||||
std::string err_head = "CenterCrop: ";
|
||||
dsize_t rank = input->shape().Rank();
|
||||
|
||||
err_msg += (rank < kMinImageRank)
|
||||
? "input tensor should have at least 2 dimensions, but got: " + std::to_string(rank) + "\t"
|
||||
: "";
|
||||
err_msg += (crop_het_ <= 0 || crop_wid_ <= 0)
|
||||
? "crop size needs to be positive integers, but got crop height:" + std::to_string(crop_het_) +
|
||||
", crop width: " + std::to_string(crop_wid_) + "\t"
|
||||
: "";
|
||||
CHECK_FAIL_RETURN_SYNTAX_ERROR(err_msg.length() == 0, err_head + err_msg);
|
||||
|
||||
if (rank <= kDefaultImageRank) { // images
|
||||
RETURN_IF_NOT_OK(CenterCropImg(input, output));
|
||||
} else { // deal with videos
|
||||
// reshape input to nhwc
|
||||
auto input_shape = input->shape();
|
||||
dsize_t num_batch = input->Size() / (input_shape[-3] * input_shape[-2] * input_shape[-1]);
|
||||
TensorShape new_shape({num_batch, input_shape[-3], input_shape[-2], input_shape[-1]});
|
||||
RETURN_IF_NOT_OK(input->Reshape(new_shape));
|
||||
|
||||
// split [N, H, W, C] to N [H, W, C], and center crop N [H, W, C]
|
||||
std::vector<std::shared_ptr<Tensor>> input_vector_hwc, output_vector_hwc;
|
||||
RETURN_IF_NOT_OK(BatchTensorToTensorVector(input, &input_vector_hwc));
|
||||
for (int i = 0; i < num_batch; i++) {
|
||||
std::shared_ptr<Tensor> center_crop;
|
||||
RETURN_IF_NOT_OK(CenterCropImg(input_vector_hwc[i], ¢er_crop));
|
||||
output_vector_hwc.push_back(center_crop);
|
||||
}
|
||||
|
||||
// integrate N [H, W, C] to [N, H, W, C], and reshape [..., H, W, C]
|
||||
RETURN_IF_NOT_OK(TensorVectorToBatchTensor(output_vector_hwc, output));
|
||||
// reshape output before return, only height and width are changed
|
||||
std::shared_ptr<TensorShape> output_shape_new = nullptr;
|
||||
RETURN_IF_NOT_OK(ConstructShape(input_shape, &output_shape_new));
|
||||
RETURN_IF_NOT_OK((*output)->Reshape(*output_shape_new));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void CenterCropOp::Print(std::ostream &out) const {
|
||||
out << "CenterCropOp: "
|
||||
<< "cropWidth: " << crop_wid_ << "cropHeight: " << crop_het_ << "\n";
|
||||
}
|
||||
|
||||
Status CenterCropOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
|
||||
RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
|
||||
outputs.clear();
|
||||
TensorShape out = TensorShape{crop_het_, crop_wid_};
|
||||
if (inputs[0].Rank() == 2) {
|
||||
(void)outputs.emplace_back(out);
|
||||
if (inputs[0].Rank() == kMinImageRank) {
|
||||
outputs.emplace_back(out);
|
||||
}
|
||||
if (inputs[0].Rank() == 3) {
|
||||
(void)outputs.emplace_back(out.AppendDim(inputs[0][2]));
|
||||
if (inputs[0].Rank() == kDefaultImageRank) {
|
||||
outputs.emplace_back(out.AppendDim(inputs[0][kChannelIndexHWC]));
|
||||
}
|
||||
if (!outputs.empty()) {
|
||||
return Status::OK();
|
||||
if (inputs[0].Rank() > kDefaultImageRank) {
|
||||
std::shared_ptr<TensorShape> output_shape_new = nullptr;
|
||||
RETURN_IF_NOT_OK(ConstructShape(inputs[0], &output_shape_new));
|
||||
outputs.emplace_back(*output_shape_new);
|
||||
}
|
||||
return Status(StatusCode::kMDUnexpectedError,
|
||||
"CenterCrop: invalid input shape, expected 2D or 3D input, but got input dimension is:" +
|
||||
std::to_string(inputs[0].Rank()));
|
||||
if (!outputs.empty()) return Status::OK();
|
||||
return Status(
|
||||
StatusCode::kMDUnexpectedError,
|
||||
"CenterCrop: input tensor should have at least 2 dimensions, but got: " + std::to_string(inputs[0].Rank()));
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -43,6 +43,10 @@ class CenterCropOp : public TensorOp {
|
|||
std::string Name() const override { return kCenterCropOp; }
|
||||
|
||||
private:
|
||||
Status CenterCropImg(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
|
||||
|
||||
Status ConstructShape(const TensorShape &in_shape, std::shared_ptr<TensorShape> *out_shape);
|
||||
|
||||
int32_t crop_het_;
|
||||
int32_t crop_wid_;
|
||||
};
|
||||
|
|
|
@ -343,7 +343,7 @@ class CenterCrop(ImageTensorOperation):
|
|||
Raises:
|
||||
TypeError: If `size` is not of type int or sequence.
|
||||
ValueError: If `size` is less than or equal to 0.
|
||||
RuntimeError: If given tensor shape is not <H, W> or <H, W, C>.
|
||||
RuntimeError: If given tensor shape is not <H, W> or <..., H, W, C>.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
|
|
@ -613,7 +613,7 @@ class CenterCrop(ImageTensorOperation, PyTensorOperation):
|
|||
Raises:
|
||||
TypeError: If `size` is not of type integer or sequence.
|
||||
ValueError: If `size` is less than or equal to 0.
|
||||
RuntimeError: If given tensor shape is not <H, W> or <H, W, C>.
|
||||
RuntimeError: If given tensor shape is not <H, W> or <..., H, W, C>.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
|
|
@ -288,6 +288,58 @@ TEST_F(MindDataTestPipeline, TestCenterCrop) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: CenterCrop
|
||||
/// Description: Use batched dataset as video inputs
|
||||
/// Expectation: The log will print correct shape
|
||||
TEST_F(MindDataTestPipeline, TestCenterCropBatch) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCenterCropBatch.";
|
||||
|
||||
// Create an ImageFolder Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 5));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Repeat operation on ds
|
||||
int32_t repeat_num = 3;
|
||||
ds = ds->Repeat(repeat_num);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
// Create a Batch operation on ds, choose batch size 5 to test high dimension input
|
||||
int32_t batch_size = 5;
|
||||
ds = ds->Batch(batch_size);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create centre crop object with square crop
|
||||
const std::vector<int32_t> crop_size{30};
|
||||
std::shared_ptr<TensorTransform> centre_out1 = std::make_shared<vision::CenterCrop>(crop_size);
|
||||
// Note: No need to check for output after calling API class constructor
|
||||
|
||||
// Create a Map operation on ds
|
||||
ds = ds->Map({centre_out1});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 3);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: CenterCrop op
|
||||
/// Description: Test CenterCrop op basic usage
|
||||
/// Expectation: Output is equal to the expected output
|
||||
|
|
|
@ -165,6 +165,44 @@ def test_center_crop_errors():
|
|||
str(e)
|
||||
|
||||
|
||||
def test_center_crop_high_dimensions():
|
||||
"""
|
||||
Feature: CenterCrop
|
||||
Description: Use randomly generated tensors and batched dataset as video inputs
|
||||
Expectation: Cropped images should in correct shape
|
||||
"""
|
||||
logger.info("Test CenterCrop using video inputs.")
|
||||
# use randomly generated tensor for testing
|
||||
video_frames = np.random.randint(
|
||||
0, 255, size=(32, 64, 64, 3), dtype=np.uint8)
|
||||
center_crop_op = vision.CenterCrop([32, 32])
|
||||
video_frames = center_crop_op(video_frames)
|
||||
assert video_frames.shape[1] == 32
|
||||
assert video_frames.shape[2] == 32
|
||||
|
||||
# use a batch of real image for testing
|
||||
# First dataset
|
||||
height = 200
|
||||
width = 200
|
||||
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
decode_op = vision.Decode()
|
||||
center_crop_op = vision.CenterCrop([height, width])
|
||||
data1 = data1.map(operations=decode_op, input_columns=["image"])
|
||||
data1_batch = data1.batch(batch_size=2)
|
||||
|
||||
for item in data1_batch.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
original_channel = item["image"].shape[-1]
|
||||
|
||||
data1_batch = data1_batch.map(
|
||||
operations=center_crop_op, input_columns=["image"])
|
||||
|
||||
for item in data1_batch.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
shape = item["image"].shape
|
||||
assert shape[-3] == height
|
||||
assert shape[-2] == width
|
||||
assert shape[-1] == original_channel
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_center_crop_op(600, 600, plot=True)
|
||||
test_center_crop_op(300, 600)
|
||||
|
@ -172,3 +210,4 @@ if __name__ == "__main__":
|
|||
test_center_crop_md5()
|
||||
test_center_crop_comp(plot=True)
|
||||
test_crop_grayscale()
|
||||
test_center_crop_high_dimensions()
|
||||
|
|
Loading…
Reference in New Issue