forked from mindspore-Ecosystem/mindspore
update mindspore/core/ops/mirror_pad.cc.
Signed-off-by: fangzehua <fangzehua1@huawei.com>
This commit is contained in:
parent
5b3f650d4a
commit
7f297125bc
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include <set>
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include "ops/mirror_pad.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
@ -32,6 +33,24 @@ constexpr int64_t MAX_PADDINGS = 5;
|
|||
void MirrorPad::set_mode(const std::string &mode) { (void)AddAttr(kMode, api::MakeValue(mode)); }
|
||||
std::string MirrorPad::get_mode() const { return GetValue<std::string>(GetAttr(kMode)); }
|
||||
|
||||
void CheckPaddingParam(const std::vector<int64_t> &paddings_shape, const std::vector<int64_t> &x_shape,
|
||||
const std::string &prim_name) {
|
||||
if (paddings_shape.size() != kPaddingsSecondDimSize) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', paddings must be equal to 2 dims, but got "
|
||||
<< paddings_shape.size();
|
||||
}
|
||||
if (paddings_shape[1] != kPaddingsSecondDimSize) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', paddings must be a matrix with 2 columns, but got "
|
||||
<< paddings_shape[1];
|
||||
}
|
||||
if (static_cast<size_t>(paddings_shape[0]) != x_shape.size()) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', paddings.shape[0] must equal to input's rank, but got "
|
||||
<< paddings_shape[0];
|
||||
}
|
||||
MS_LOG(DEBUG) << "For '" << prim_name << "' padding shape: " << paddings_shape;
|
||||
return;
|
||||
}
|
||||
|
||||
MIND_API_OPERATOR_IMPL(MirrorPad, BaseOperator);
|
||||
class MirrorPadInfer : public abstract::OpInferBase {
|
||||
public:
|
||||
|
@ -39,71 +58,72 @@ class MirrorPadInfer : public abstract::OpInferBase {
|
|||
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto paddings = input_args[1]->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(paddings);
|
||||
auto paddings_arg = CheckAndConvertUtils::CheckTensorIntValue(kPaddings, paddings, prim_name);
|
||||
std::vector<std::pair<int64_t, int64_t>> paddings_attr;
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
// ToSupport Dynamic rank
|
||||
if (IsDynamicRank(x_shape)) {
|
||||
auto input_x_shape_ptr = input_args[0]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(input_x_shape_ptr);
|
||||
auto input_x_shape = input_x_shape_ptr->cast<abstract::ShapePtr>();
|
||||
// Dynamic rank process.
|
||||
if (IsDynamicRank(input_x_shape->shape())) {
|
||||
return std::make_shared<abstract::Shape>(std::vector<int64_t>{-2});
|
||||
}
|
||||
auto paddings = input_args[1]->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(paddings);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
// if shape of x is determined and padding value is unknown, return a all -1 shape
|
||||
if (paddings->isa<AnyValue>() || paddings->isa<None>()) {
|
||||
return std::make_shared<abstract::Shape>(ShapeVector(x_shape.size(), -1));
|
||||
}
|
||||
auto paddings_arg = CheckAndConvertUtils::CheckTensorIntValue(kPaddings, paddings, prim_name);
|
||||
std::vector<std::pair<int64_t, int64_t>> paddings_attr;
|
||||
|
||||
auto paddings_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto mode = GetValue<std::string>(primitive->GetAttr(kMode));
|
||||
if (paddings_shape.size() != kPaddingsSecondDimSize) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', paddings must be equal to 2 dims, but got "
|
||||
<< paddings_shape.size();
|
||||
}
|
||||
if (paddings_shape[1] != kPaddingsSecondDimSize) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', paddings must be a matrix with 2 columns, but got "
|
||||
<< paddings_shape[1];
|
||||
}
|
||||
if (static_cast<size_t>(paddings_shape[0]) != x_shape.size()) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', paddings.shape[0] must equal to input's rank, but got "
|
||||
<< paddings_shape[0];
|
||||
}
|
||||
for (size_t i = 0; i < paddings_arg.size(); i = i + kPaddingsSecondDimSize) {
|
||||
CheckPaddingParam(paddings_shape, x_shape, prim_name);
|
||||
for (size_t i = 0; i < paddings_arg.size(); i = i + static_cast<size_t>(kPaddingsSecondDimSize)) {
|
||||
paddings_attr.push_back(std::make_pair(paddings_arg[i], paddings_arg[i + 1]));
|
||||
}
|
||||
(void)CheckAndConvertUtils::CheckInteger(kPaddingsSize, SizeToLong(paddings_attr.size()), kEqual,
|
||||
SizeToLong(x_shape.size()), prim_name);
|
||||
auto input_x_shape_ptr = input_args[0]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(input_x_shape_ptr);
|
||||
if (input_x_shape_ptr->IsDynamic()) {
|
||||
return input_args[0]->BuildShape()->cast<abstract::ShapePtr>();
|
||||
}
|
||||
int64_t size = SizeToLong(x_shape.size());
|
||||
int64_t size = static_cast<int64_t>(x_shape.size());
|
||||
if (size < 0 || size > MAX_PADDINGS) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name
|
||||
<< "', the dimension of input only supports less than or equal to 5 dims, but got "
|
||||
<< size << " dims";
|
||||
}
|
||||
for (size_t i = 0; i < LongToSize(size); i++) {
|
||||
for (int64_t i = 0; i < size; i++) {
|
||||
if (x_shape[i] == -1) {
|
||||
continue;
|
||||
}
|
||||
if (paddings_attr[i].first < 0 || paddings_attr[i].second < 0) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', all elements of paddings must be >= 0.";
|
||||
}
|
||||
if (mode == "SYMMETRIC") {
|
||||
if (paddings_attr[i].first > x_shape[i] || paddings_attr[i].second > x_shape[i]) {
|
||||
if (paddings_attr[i].first > static_cast<int64_t>(x_shape[i]) ||
|
||||
paddings_attr[i].second > static_cast<int64_t>(x_shape[i])) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name
|
||||
<< "', paddings must be no greater "
|
||||
"than the dimension size: ["
|
||||
<< paddings_attr[i].first << "], [" << paddings_attr[i].second << "] greater than ["
|
||||
<< x_shape[i] << "]";
|
||||
<< static_cast<int64_t>(x_shape[i]) << "]";
|
||||
}
|
||||
} else if (mode == "REFLECT") {
|
||||
if (paddings_attr[i].first >= x_shape[i] || paddings_attr[i].second >= x_shape[i]) {
|
||||
if (paddings_attr[i].first >= static_cast<int64_t>(x_shape[i]) ||
|
||||
paddings_attr[i].second >= static_cast<int64_t>(x_shape[i])) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name
|
||||
<< "', paddings must be no greater "
|
||||
"than the dimension size: ["
|
||||
<< paddings_attr[i].first << "], [" << paddings_attr[i].second << "] not less than ["
|
||||
<< x_shape[i] << "]";
|
||||
<< static_cast<int64_t>(x_shape[i]) << "]";
|
||||
}
|
||||
}
|
||||
}
|
||||
std::vector<int64_t> out_shape;
|
||||
for (size_t i = 0; i < x_shape.size(); i++) {
|
||||
(void)out_shape.emplace_back(x_shape[i] + paddings_attr[i].first + paddings_attr[i].second);
|
||||
// In dynamic situation , if input axis is dynamic, output axis is dynamic too.
|
||||
if (x_shape[i] == -1) {
|
||||
(void)out_shape.emplace_back(-1);
|
||||
} else {
|
||||
(void)out_shape.emplace_back(x_shape[i] + paddings_attr[i].first + paddings_attr[i].second);
|
||||
}
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue