update mindspore/core/ops/mirror_pad.cc.

Signed-off-by: fangzehua <fangzehua1@huawei.com>
fangzehua 2022-12-28 12:32:07 +00:00 committed by Gitee
parent 5b3f650d4a
commit 7f297125bc
1 changed file with 52 additions and 32 deletions

mindspore/core/ops/mirror_pad.cc

@@ -16,6 +16,7 @@
 #include <set>
 #include <utility>
 #include <string>
+#include "ops/mirror_pad.h"
 #include "ops/op_utils.h"
 #include "utils/check_convert_utils.h"
@@ -32,6 +33,24 @@ constexpr int64_t MAX_PADDINGS = 5;
 void MirrorPad::set_mode(const std::string &mode) { (void)AddAttr(kMode, api::MakeValue(mode)); }
 std::string MirrorPad::get_mode() const { return GetValue<std::string>(GetAttr(kMode)); }
+void CheckPaddingParam(const std::vector<int64_t> &paddings_shape, const std::vector<int64_t> &x_shape,
+                       const std::string &prim_name) {
+  if (paddings_shape.size() != kPaddingsSecondDimSize) {
+    MS_EXCEPTION(ValueError) << "For '" << prim_name << "', paddings must be a 2-dimensional tensor, but got "
+                             << paddings_shape.size() << " dims";
+  }
+  if (paddings_shape[1] != kPaddingsSecondDimSize) {
+    MS_EXCEPTION(ValueError) << "For '" << prim_name << "', paddings must be a matrix with 2 columns, but got "
+                             << paddings_shape[1];
+  }
+  if (static_cast<size_t>(paddings_shape[0]) != x_shape.size()) {
+    MS_EXCEPTION(ValueError) << "For '" << prim_name << "', paddings.shape[0] must be equal to the input's rank, but got "
+                             << paddings_shape[0];
+  }
+  MS_LOG(DEBUG) << "For '" << prim_name << "' padding shape: " << paddings_shape;
+}
 MIND_API_OPERATOR_IMPL(MirrorPad, BaseOperator);
 class MirrorPadInfer : public abstract::OpInferBase {
  public:
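The checks that the new CheckPaddingParam helper centralizes are easiest to see in isolation. Below is a minimal, self-contained sketch of the same three validations, using plain standard-library exceptions in place of MS_EXCEPTION and a hypothetical CheckPaddingShape name; it is not MindSpore API:

#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical standalone version of the validation: paddings must be an
// N x 2 matrix of {pad_before, pad_after} rows, one row per input axis.
void CheckPaddingShape(const std::vector<int64_t> &paddings_shape, const std::vector<int64_t> &x_shape) {
  if (paddings_shape.size() != 2) {
    throw std::invalid_argument("paddings must be a 2-dimensional tensor");
  }
  if (paddings_shape[1] != 2) {
    throw std::invalid_argument("paddings must be a matrix with 2 columns");
  }
  if (static_cast<size_t>(paddings_shape[0]) != x_shape.size()) {
    throw std::invalid_argument("paddings.shape[0] must equal the input's rank");
  }
}

For example, a 3-D input requires paddings_shape == {3, 2}; a {2, 2} or {3, 3} paddings tensor is rejected.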
@@ -39,71 +58,72 @@ class MirrorPadInfer : public abstract::OpInferBase {
                           const std::vector<AbstractBasePtr> &input_args) const override {
     MS_EXCEPTION_IF_NULL(primitive);
     auto prim_name = primitive->name();
-    auto paddings = input_args[1]->BuildValue();
-    MS_EXCEPTION_IF_NULL(paddings);
-    auto paddings_arg = CheckAndConvertUtils::CheckTensorIntValue(kPaddings, paddings, prim_name);
-    std::vector<std::pair<int64_t, int64_t>> paddings_attr;
-    auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
-    // ToSupport Dynamic rank
-    if (IsDynamicRank(x_shape)) {
+    auto input_x_shape_ptr = input_args[0]->BuildShape();
+    MS_EXCEPTION_IF_NULL(input_x_shape_ptr);
+    auto input_x_shape = input_x_shape_ptr->cast<abstract::ShapePtr>();
+    // Dynamic rank: the output rank is unknown too, represented by the shape {-2}.
+    if (IsDynamicRank(input_x_shape->shape())) {
       return std::make_shared<abstract::Shape>(std::vector<int64_t>{-2});
     }
+    auto paddings = input_args[1]->BuildValue();
+    MS_EXCEPTION_IF_NULL(paddings);
+    auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
+    // If the shape of x is known but the paddings value is unknown, return an all -1 shape.
+    if (paddings->isa<AnyValue>() || paddings->isa<None>()) {
+      return std::make_shared<abstract::Shape>(ShapeVector(x_shape.size(), -1));
+    }
+    auto paddings_arg = CheckAndConvertUtils::CheckTensorIntValue(kPaddings, paddings, prim_name);
+    std::vector<std::pair<int64_t, int64_t>> paddings_attr;
     auto paddings_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
     auto mode = GetValue<std::string>(primitive->GetAttr(kMode));
-    if (paddings_shape.size() != kPaddingsSecondDimSize) {
-      MS_EXCEPTION(ValueError) << "For '" << prim_name << "', paddings must be equal to 2 dims, but got "
-                               << paddings_shape.size();
-    }
-    if (paddings_shape[1] != kPaddingsSecondDimSize) {
-      MS_EXCEPTION(ValueError) << "For '" << prim_name << "', paddings must be a matrix with 2 columns, but got "
-                               << paddings_shape[1];
-    }
-    if (static_cast<size_t>(paddings_shape[0]) != x_shape.size()) {
-      MS_EXCEPTION(ValueError) << "For '" << prim_name << "', paddings.shape[0] must equal to input's rank, but got "
-                               << paddings_shape[0];
-    }
-    for (size_t i = 0; i < paddings_arg.size(); i = i + kPaddingsSecondDimSize) {
+    CheckPaddingParam(paddings_shape, x_shape, prim_name);
+    for (size_t i = 0; i < paddings_arg.size(); i = i + static_cast<size_t>(kPaddingsSecondDimSize)) {
       paddings_attr.push_back(std::make_pair(paddings_arg[i], paddings_arg[i + 1]));
     }
     (void)CheckAndConvertUtils::CheckInteger(kPaddingsSize, SizeToLong(paddings_attr.size()), kEqual,
                                              SizeToLong(x_shape.size()), prim_name);
-    auto input_x_shape_ptr = input_args[0]->BuildShape();
-    MS_EXCEPTION_IF_NULL(input_x_shape_ptr);
-    if (input_x_shape_ptr->IsDynamic()) {
-      return input_args[0]->BuildShape()->cast<abstract::ShapePtr>();
-    }
-    int64_t size = SizeToLong(x_shape.size());
+    int64_t size = static_cast<int64_t>(x_shape.size());
     if (size < 0 || size > MAX_PADDINGS) {
       MS_EXCEPTION(ValueError) << "For '" << prim_name
                                << "', the dimension of input only supports less than or equal to 5 dims, but got "
                                << size << " dims";
     }
-    for (size_t i = 0; i < LongToSize(size); i++) {
+    for (int64_t i = 0; i < size; i++) {
+      // Skip the range checks on a dynamic (-1) axis; its extent is unknown at infer time.
+      if (x_shape[i] == -1) {
+        continue;
+      }
       if (paddings_attr[i].first < 0 || paddings_attr[i].second < 0) {
         MS_EXCEPTION(ValueError) << "For '" << prim_name << "', all elements of paddings must be >= 0.";
       }
       if (mode == "SYMMETRIC") {
-        if (paddings_attr[i].first > x_shape[i] || paddings_attr[i].second > x_shape[i]) {
+        if (paddings_attr[i].first > static_cast<int64_t>(x_shape[i]) ||
+            paddings_attr[i].second > static_cast<int64_t>(x_shape[i])) {
           MS_EXCEPTION(ValueError) << "For '" << prim_name
                                    << "', paddings must be no greater "
                                       "than the dimension size: ["
                                    << paddings_attr[i].first << "], [" << paddings_attr[i].second << "] greater than ["
-                                   << x_shape[i] << "]";
+                                   << static_cast<int64_t>(x_shape[i]) << "]";
         }
       } else if (mode == "REFLECT") {
-        if (paddings_attr[i].first >= x_shape[i] || paddings_attr[i].second >= x_shape[i]) {
+        if (paddings_attr[i].first >= static_cast<int64_t>(x_shape[i]) ||
+            paddings_attr[i].second >= static_cast<int64_t>(x_shape[i])) {
           MS_EXCEPTION(ValueError) << "For '" << prim_name
                                    << "', paddings must be no greater "
                                       "than the dimension size: ["
                                    << paddings_attr[i].first << "], [" << paddings_attr[i].second << "] not less than ["
-                                   << x_shape[i] << "]";
+                                   << static_cast<int64_t>(x_shape[i]) << "]";
         }
       }
     }
     std::vector<int64_t> out_shape;
     for (size_t i = 0; i < x_shape.size(); i++) {
-      (void)out_shape.emplace_back(x_shape[i] + paddings_attr[i].first + paddings_attr[i].second);
+      // In the dynamic shape case, if an input axis is dynamic, the corresponding output axis is dynamic too.
+      if (x_shape[i] == -1) {
+        (void)out_shape.emplace_back(-1);
+      } else {
+        (void)out_shape.emplace_back(x_shape[i] + paddings_attr[i].first + paddings_attr[i].second);
+      }
     }
     return std::make_shared<abstract::Shape>(out_shape);
   }
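Taken together, the new infer logic implements a simple shape rule: out[i] = x[i] + pad_before[i] + pad_after[i], where SYMMETRIC mode requires each pad to be at most the axis length, REFLECT mode requires it to be strictly smaller, a -1 (dynamic) axis stays -1, and a fully dynamic rank is reported as {-2}. Below is a hedged sketch of that rule outside the MindSpore infrastructure, with a hypothetical MirrorPadOutShape helper that is not the operator's real API:

#include <cstdint>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Sketch of the output-shape rule: a -1 (dynamic) axis is preserved;
// otherwise out[i] = x[i] + pad_before + pad_after, subject to the
// SYMMETRIC (pad <= dim) or REFLECT (pad < dim) constraint.
std::vector<int64_t> MirrorPadOutShape(const std::vector<int64_t> &x_shape,
                                       const std::vector<std::pair<int64_t, int64_t>> &pads,
                                       const std::string &mode) {
  std::vector<int64_t> out;
  for (size_t i = 0; i < x_shape.size(); ++i) {
    if (x_shape[i] == -1) {  // dynamic axis: the output axis is dynamic too
      out.push_back(-1);
      continue;
    }
    const int64_t before = pads[i].first;
    const int64_t after = pads[i].second;
    if (before < 0 || after < 0) {
      throw std::invalid_argument("all elements of paddings must be >= 0");
    }
    const bool too_large = (mode == "SYMMETRIC") ? (before > x_shape[i] || after > x_shape[i])
                                                 : (before >= x_shape[i] || after >= x_shape[i]);
    if (too_large) {
      throw std::invalid_argument("paddings too large for mode " + mode);
    }
    out.push_back(x_shape[i] + before + after);
  }
  return out;
}

// Example: a {2, 3} input with pads {{1, 1}, {0, 2}} in REFLECT mode yields
// {4, 5}; the same pads on a {-1, 3} input yield {-1, 5}.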