forked from mindspore-Ecosystem/mindspore
Fix max function length of CutMixBatchOp::Compute
Signed-off-by: alex-yuyue <yue.yu1@huawei.com>
This commit is contained in:
parent
d35939603a
commit
b2a0265784
|
@ -48,12 +48,10 @@ void CutMixBatchOp::GetCropBox(int height, int width, float lam, int *x, int *y,
|
|||
*crop_height = std::clamp(y2 - *y, 1, height - 1);
|
||||
}
|
||||
|
||||
Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {
|
||||
Status CutMixBatchOp::ValidateCutMixBatch(const TensorRow &input) {
|
||||
if (input.size() < 2) {
|
||||
RETURN_STATUS_UNEXPECTED("CutMixBatch: both image and label columns are required.");
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<Tensor>> images;
|
||||
std::vector<int64_t> image_shape = input.at(0)->shape().AsVector();
|
||||
std::vector<int64_t> label_shape = input.at(1)->shape().AsVector();
|
||||
|
||||
|
@ -80,14 +78,98 @@ Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {
|
|||
RETURN_STATUS_UNEXPECTED("CutMixBatch: image doesn't match the NHWC format.");
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CutMixBatchOp::ComputeImage(const TensorRow &input, const int64_t rand_indx_i, const float lam, float *label_lam,
|
||||
std::shared_ptr<Tensor> *image_i) {
|
||||
std::vector<int64_t> image_shape = input.at(0)->shape().AsVector();
|
||||
int x, y, crop_width, crop_height;
|
||||
// Get a random image
|
||||
TensorShape remaining({-1});
|
||||
uchar *start_addr_of_index = nullptr;
|
||||
std::shared_ptr<Tensor> rand_image;
|
||||
|
||||
RETURN_IF_NOT_OK(input.at(0)->StartAddrOfIndex({rand_indx_i, 0, 0, 0}, &start_addr_of_index, &remaining));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromMemory(TensorShape({image_shape[1], image_shape[2], image_shape[3]}),
|
||||
input.at(0)->type(), start_addr_of_index, &rand_image));
|
||||
|
||||
// Compute image
|
||||
if (image_batch_format_ == ImageBatchFormat::kNHWC) {
|
||||
// NHWC Format
|
||||
GetCropBox(static_cast<int32_t>(image_shape[1]), static_cast<int32_t>(image_shape[2]), lam, &x, &y, &crop_width,
|
||||
&crop_height);
|
||||
std::shared_ptr<Tensor> cropped;
|
||||
RETURN_IF_NOT_OK(Crop(rand_image, &cropped, x, y, crop_width, crop_height));
|
||||
RETURN_IF_NOT_OK(MaskWithTensor(cropped, image_i, x, y, crop_width, crop_height, ImageFormat::HWC));
|
||||
*label_lam = 1 - (crop_width * crop_height / static_cast<float>(image_shape[1] * image_shape[2]));
|
||||
} else {
|
||||
// NCHW Format
|
||||
GetCropBox(static_cast<int32_t>(image_shape[2]), static_cast<int32_t>(image_shape[3]), lam, &x, &y, &crop_width,
|
||||
&crop_height);
|
||||
std::vector<std::shared_ptr<Tensor>> channels; // A vector holding channels of the CHW image
|
||||
std::vector<std::shared_ptr<Tensor>> cropped_channels; // A vector holding the channels of the cropped CHW
|
||||
RETURN_IF_NOT_OK(BatchTensorToTensorVector(rand_image, &channels));
|
||||
for (auto channel : channels) {
|
||||
// Call crop for each single channel
|
||||
std::shared_ptr<Tensor> cropped_channel;
|
||||
RETURN_IF_NOT_OK(Crop(channel, &cropped_channel, x, y, crop_width, crop_height));
|
||||
cropped_channels.push_back(cropped_channel);
|
||||
}
|
||||
std::shared_ptr<Tensor> cropped;
|
||||
// Merge channels to a single tensor
|
||||
RETURN_IF_NOT_OK(TensorVectorToBatchTensor(cropped_channels, &cropped));
|
||||
|
||||
RETURN_IF_NOT_OK(MaskWithTensor(cropped, image_i, x, y, crop_width, crop_height, ImageFormat::CHW));
|
||||
*label_lam = 1 - (crop_width * crop_height / static_cast<float>(image_shape[2] * image_shape[3]));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CutMixBatchOp::ComputeLabel(const TensorRow &input, const int64_t rand_indx_i, const int64_t index_i,
|
||||
const int64_t row_labels, const int64_t num_classes,
|
||||
const std::size_t label_shape_size, const float label_lam,
|
||||
std::shared_ptr<Tensor> *out_labels) {
|
||||
// Compute labels
|
||||
for (int64_t j = 0; j < row_labels; j++) {
|
||||
for (int64_t k = 0; k < num_classes; k++) {
|
||||
std::vector<int64_t> first_index = label_shape_size == 3 ? std::vector{index_i, j, k} : std::vector{index_i, k};
|
||||
std::vector<int64_t> second_index =
|
||||
label_shape_size == 3 ? std::vector{rand_indx_i, j, k} : std::vector{rand_indx_i, k};
|
||||
if (input.at(1)->type().IsSignedInt()) {
|
||||
int64_t first_value, second_value;
|
||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, first_index));
|
||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, second_index));
|
||||
RETURN_IF_NOT_OK(
|
||||
(*out_labels)->SetItemAt(first_index, label_lam * first_value + (1 - label_lam) * second_value));
|
||||
} else {
|
||||
uint64_t first_value, second_value;
|
||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, first_index));
|
||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, second_index));
|
||||
RETURN_IF_NOT_OK(
|
||||
(*out_labels)->SetItemAt(first_index, label_lam * first_value + (1 - label_lam) * second_value));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {
|
||||
IO_CHECK_VECTOR(input, output);
|
||||
RETURN_IF_NOT_OK(ValidateCutMixBatch(input));
|
||||
std::vector<int64_t> image_shape = input.at(0)->shape().AsVector();
|
||||
std::vector<int64_t> label_shape = input.at(1)->shape().AsVector();
|
||||
|
||||
// Move images into a vector of Tensors
|
||||
std::vector<std::shared_ptr<Tensor>> images;
|
||||
RETURN_IF_NOT_OK(BatchTensorToTensorVector(input.at(0), &images));
|
||||
|
||||
// Calculate random labels
|
||||
std::vector<int64_t> rand_indx;
|
||||
for (int64_t i = 0; i < images.size(); i++) rand_indx.push_back(i);
|
||||
std::shuffle(rand_indx.begin(), rand_indx.end(), rnd_);
|
||||
|
||||
std::gamma_distribution<float> gamma_distribution(alpha_, 1);
|
||||
std::uniform_real_distribution<double> uniform_distribution(0.0, 1.0);
|
||||
|
||||
|
@ -107,69 +189,12 @@ Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {
|
|||
float lam = x1 / (x1 + x2);
|
||||
double random_number = uniform_distribution(rnd_);
|
||||
if (random_number < prob_) {
|
||||
int x, y, crop_width, crop_height;
|
||||
float label_lam; // lambda used for labels
|
||||
|
||||
// Get a random image
|
||||
TensorShape remaining({-1});
|
||||
uchar *start_addr_of_index = nullptr;
|
||||
std::shared_ptr<Tensor> rand_image;
|
||||
RETURN_IF_NOT_OK(input.at(0)->StartAddrOfIndex({rand_indx[i], 0, 0, 0}, &start_addr_of_index, &remaining));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromMemory(TensorShape({image_shape[1], image_shape[2], image_shape[3]}),
|
||||
input.at(0)->type(), start_addr_of_index, &rand_image));
|
||||
|
||||
// Compute image
|
||||
if (image_batch_format_ == ImageBatchFormat::kNHWC) {
|
||||
// NHWC Format
|
||||
GetCropBox(static_cast<int32_t>(image_shape[1]), static_cast<int32_t>(image_shape[2]), lam, &x, &y, &crop_width,
|
||||
&crop_height);
|
||||
std::shared_ptr<Tensor> cropped;
|
||||
RETURN_IF_NOT_OK(Crop(rand_image, &cropped, x, y, crop_width, crop_height));
|
||||
RETURN_IF_NOT_OK(MaskWithTensor(cropped, &images[i], x, y, crop_width, crop_height, ImageFormat::HWC));
|
||||
label_lam = 1 - (crop_width * crop_height / static_cast<float>(image_shape[1] * image_shape[2]));
|
||||
} else {
|
||||
// NCHW Format
|
||||
GetCropBox(static_cast<int32_t>(image_shape[2]), static_cast<int32_t>(image_shape[3]), lam, &x, &y, &crop_width,
|
||||
&crop_height);
|
||||
std::vector<std::shared_ptr<Tensor>> channels; // A vector holding channels of the CHW image
|
||||
std::vector<std::shared_ptr<Tensor>> cropped_channels; // A vector holding the channels of the cropped CHW
|
||||
RETURN_IF_NOT_OK(BatchTensorToTensorVector(rand_image, &channels));
|
||||
for (auto channel : channels) {
|
||||
// Call crop for each single channel
|
||||
std::shared_ptr<Tensor> cropped_channel;
|
||||
RETURN_IF_NOT_OK(Crop(channel, &cropped_channel, x, y, crop_width, crop_height));
|
||||
cropped_channels.push_back(cropped_channel);
|
||||
}
|
||||
std::shared_ptr<Tensor> cropped;
|
||||
// Merge channels to a single tensor
|
||||
RETURN_IF_NOT_OK(TensorVectorToBatchTensor(cropped_channels, &cropped));
|
||||
|
||||
RETURN_IF_NOT_OK(MaskWithTensor(cropped, &images[i], x, y, crop_width, crop_height, ImageFormat::CHW));
|
||||
label_lam = 1 - (crop_width * crop_height / static_cast<float>(image_shape[2] * image_shape[3]));
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(ComputeImage(input, rand_indx[i], lam, &label_lam, &images[i]));
|
||||
// Compute labels
|
||||
|
||||
for (int64_t j = 0; j < row_labels; j++) {
|
||||
for (int64_t k = 0; k < num_classes; k++) {
|
||||
std::vector<int64_t> first_index = label_shape.size() == 3 ? std::vector{i, j, k} : std::vector{i, k};
|
||||
std::vector<int64_t> second_index =
|
||||
label_shape.size() == 3 ? std::vector{rand_indx[i], j, k} : std::vector{rand_indx[i], k};
|
||||
if (input.at(1)->type().IsSignedInt()) {
|
||||
int64_t first_value, second_value;
|
||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, first_index));
|
||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, second_index));
|
||||
RETURN_IF_NOT_OK(
|
||||
out_labels->SetItemAt(first_index, label_lam * first_value + (1 - label_lam) * second_value));
|
||||
} else {
|
||||
uint64_t first_value, second_value;
|
||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, first_index));
|
||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, second_index));
|
||||
RETURN_IF_NOT_OK(
|
||||
out_labels->SetItemAt(first_index, label_lam * first_value + (1 - label_lam) * second_value));
|
||||
}
|
||||
}
|
||||
}
|
||||
RETURN_IF_NOT_OK(
|
||||
ComputeLabel(input, rand_indx[i], i, row_labels, num_classes, label_shape.size(), label_lam, &out_labels));
|
||||
}
|
||||
}
|
||||
std::shared_ptr<Tensor> out_images;
|
||||
|
|
|
@ -36,11 +36,40 @@ class CutMixBatchOp : public TensorOp {
|
|||
void Print(std::ostream &out) const override;
|
||||
|
||||
void GetCropBox(int width, int height, float lam, int *x, int *y, int *crop_width, int *crop_height);
|
||||
|
||||
Status Compute(const TensorRow &input, TensorRow *output) override;
|
||||
|
||||
std::string Name() const override { return kCutMixBatchOp; }
|
||||
|
||||
private:
|
||||
/// \brief Helper function used in Compute to validate the input TensorRow.
|
||||
/// \param[in] input Input TensorRow of CutMixBatchOp
|
||||
/// \returns Status
|
||||
Status ValidateCutMixBatch(const TensorRow &input);
|
||||
|
||||
/// \brief Helper function used in Compute to compute each image.
|
||||
/// \param[in] input Input TensorRow of CutMixBatchOp.
|
||||
/// \param[in] rand_indx_i The i-th generated random index as the start address of the input image.
|
||||
/// \param[in] lam A random variable follow Beta distribution, used in GetCropBox.
|
||||
/// \param[in] label_lam Lambda used for labels, will be updated after computing each image.
|
||||
/// \param[in] image_i The result of the i-th computed image.
|
||||
/// \returns Status
|
||||
Status ComputeImage(const TensorRow &input, const int64_t rand_indx_i, const float lam, float *label_lam,
|
||||
std::shared_ptr<Tensor> *image_i);
|
||||
|
||||
/// \brief Helper function used in Compute to compute each label corresponding to each image.
|
||||
/// \param[in] input Input TensorRow of CutMixBatchOp.
|
||||
/// \param[in] rand_indx_i The i-th generated random index as the start address of the input image.
|
||||
/// \param[in] index_i The i-th label to be generated, corresponding to the i-th computed image.
|
||||
/// \param[in] row_labels Number of rows of the label.
|
||||
/// \param[in] num_classes Number of class of the label.
|
||||
/// \param[in] label_shape_size The size of the label shape from input TensorRow.
|
||||
/// \param[in] label_lam Lambda used for setting the location.
|
||||
/// \param[in] out_labels The output of the i-th label, corresponding to the i-th computed image.
|
||||
/// \returns Status
|
||||
Status ComputeLabel(const TensorRow &input, const int64_t rand_indx_i, const int64_t index_i,
|
||||
const int64_t row_labels, const int64_t num_classes, const std::size_t label_shape_size,
|
||||
const float label_lam, std::shared_ptr<Tensor> *out_labels);
|
||||
float alpha_;
|
||||
float prob_;
|
||||
ImageBatchFormat image_batch_format_;
|
||||
|
|
Loading…
Reference in New Issue