fix pad ops doc and nulltpr

This commit is contained in:
fangzehua 2023-01-18 11:22:30 +08:00
parent 680f21a547
commit 5788279b3b
4 changed files with 48 additions and 37 deletions

View File

@ -24,8 +24,8 @@ mindspore.ops.PadV3
- **ValueError** - `mode` 是"constant"的同时 `paddings` 元素个数不是偶数。 - **ValueError** - `mode` 是"constant"的同时 `paddings` 元素个数不是偶数。
- **ValueError** - `mode` 是"constant"的同时 `paddings` 元素个数大于输入维度乘以2。 - **ValueError** - `mode` 是"constant"的同时 `paddings` 元素个数大于输入维度乘以2。
- **ValueError** - `mode` 是"edge"或"reflect"的同时 `paddings` 元素个数不是2、4或6。 - **ValueError** - `mode` 是"edge"或"reflect"的同时 `paddings` 元素个数不是2、4或6。
- **ValueError** - `mode` 是"edge"或"reflect" `x` 的维度是3 `paddings` 元素个数是2。 - **ValueError** - `mode` 是"edge"或"reflect" `x` 的维度是3 `paddings` 元素个数是2。
- **ValueError** - `mode` 是"edge"或"reflect" `x` 的维度是4 `paddings` 元素个数是4。 - **ValueError** - `mode` 是"edge"或"reflect" `x` 的维度是4 `paddings` 元素个数是4。
- **ValueError** - `mode` 是"edge"或"reflect"的同时 `x` 的维度小于3。 - **ValueError** - `mode` 是"edge"或"reflect"的同时 `x` 的维度小于3。
- **ValueError** - `mode` 是"edge"的同时 `x` 的维度大于5。 - **ValueError** - `mode` 是"edge"的同时 `x` 的维度大于5。
- **ValueError** - `mode` 是"reflect"的同时 `x` 的维度大于4。 - **ValueError** - `mode` 是"reflect"的同时 `x` 的维度大于4。

View File

@ -51,6 +51,43 @@ void CheckPaddingParam(const std::vector<int64_t> &paddings_shape, const std::ve
return; return;
} }
void CheckPaddingValue(const std::vector<std::pair<int64_t, int64_t>> &paddings_attr,
const std::vector<int64_t> &x_shape, const std::string &mode, const std::string &prim_name) {
int64_t size = static_cast<int64_t>(x_shape.size());
if (size < 0 || size > MAX_PADDINGS) {
MS_EXCEPTION(ValueError) << "For '" << prim_name
<< "', the dimension of input only supports less than or equal to 5 dims, but got " << size
<< " dims";
}
for (int64_t i = 0; i < size; i++) {
if (x_shape[i] == abstract::Shape::kShapeDimAny) {
continue;
}
if (paddings_attr[i].first < 0 || paddings_attr[i].second < 0) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', all elements of paddings must be >= 0.";
}
if (mode == "SYMMETRIC") {
if (paddings_attr[i].first > static_cast<int64_t>(x_shape[i]) ||
paddings_attr[i].second > static_cast<int64_t>(x_shape[i])) {
MS_EXCEPTION(ValueError) << "For '" << prim_name
<< "', paddings must be no greater "
"than the dimension size: ["
<< paddings_attr[i].first << "], [" << paddings_attr[i].second << "] greater than ["
<< static_cast<int64_t>(x_shape[i]) << "]";
}
} else if (mode == "REFLECT") {
if (paddings_attr[i].first >= static_cast<int64_t>(x_shape[i]) ||
paddings_attr[i].second >= static_cast<int64_t>(x_shape[i])) {
MS_EXCEPTION(ValueError) << "For '" << prim_name
<< "', paddings must be no greater "
"than the dimension size: ["
<< paddings_attr[i].first << "], [" << paddings_attr[i].second << "] not less than ["
<< static_cast<int64_t>(x_shape[i]) << "]";
}
}
}
}
MIND_API_OPERATOR_IMPL(MirrorPad, BaseOperator); MIND_API_OPERATOR_IMPL(MirrorPad, BaseOperator);
class MirrorPadInfer : public abstract::OpInferBase { class MirrorPadInfer : public abstract::OpInferBase {
public: public:
@ -58,6 +95,9 @@ class MirrorPadInfer : public abstract::OpInferBase {
const std::vector<AbstractBasePtr> &input_args) const override { const std::vector<AbstractBasePtr> &input_args) const override {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto input_x_shape_ptr = input_args[0]->BuildShape(); auto input_x_shape_ptr = input_args[0]->BuildShape();
MS_EXCEPTION_IF_NULL(input_x_shape_ptr); MS_EXCEPTION_IF_NULL(input_x_shape_ptr);
auto input_x_shape = input_x_shape_ptr->cast<abstract::ShapePtr>(); auto input_x_shape = input_x_shape_ptr->cast<abstract::ShapePtr>();
@ -83,39 +123,7 @@ class MirrorPadInfer : public abstract::OpInferBase {
} }
(void)CheckAndConvertUtils::CheckInteger(kPaddingsSize, SizeToLong(paddings_attr.size()), kEqual, (void)CheckAndConvertUtils::CheckInteger(kPaddingsSize, SizeToLong(paddings_attr.size()), kEqual,
SizeToLong(x_shape.size()), prim_name); SizeToLong(x_shape.size()), prim_name);
int64_t size = static_cast<int64_t>(x_shape.size()); CheckPaddingValue(paddings_attr, x_shape, mode, prim_name);
if (size < 0 || size > MAX_PADDINGS) {
MS_EXCEPTION(ValueError) << "For '" << prim_name
<< "', the dimension of input only supports less than or equal to 5 dims, but got "
<< size << " dims";
}
for (int64_t i = 0; i < size; i++) {
if (x_shape[i] == abstract::Shape::kShapeDimAny) {
continue;
}
if (paddings_attr[i].first < 0 || paddings_attr[i].second < 0) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', all elements of paddings must be >= 0.";
}
if (mode == "SYMMETRIC") {
if (paddings_attr[i].first > static_cast<int64_t>(x_shape[i]) ||
paddings_attr[i].second > static_cast<int64_t>(x_shape[i])) {
MS_EXCEPTION(ValueError) << "For '" << prim_name
<< "', paddings must be no greater "
"than the dimension size: ["
<< paddings_attr[i].first << "], [" << paddings_attr[i].second << "] greater than ["
<< static_cast<int64_t>(x_shape[i]) << "]";
}
} else if (mode == "REFLECT") {
if (paddings_attr[i].first >= static_cast<int64_t>(x_shape[i]) ||
paddings_attr[i].second >= static_cast<int64_t>(x_shape[i])) {
MS_EXCEPTION(ValueError) << "For '" << prim_name
<< "', paddings must be no greater "
"than the dimension size: ["
<< paddings_attr[i].first << "], [" << paddings_attr[i].second << "] not less than ["
<< static_cast<int64_t>(x_shape[i]) << "]";
}
}
}
std::vector<int64_t> out_shape; std::vector<int64_t> out_shape;
for (size_t i = 0; i < x_shape.size(); i++) { for (size_t i = 0; i < x_shape.size(); i++) {
// In dynamic situation , if input axis is dynamic, output axis is dynamic too. // In dynamic situation , if input axis is dynamic, output axis is dynamic too.

View File

@ -31,6 +31,9 @@ namespace {
TypePtr PaddingInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { TypePtr PaddingInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto name = primitive->name(); auto name = primitive->name();
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto context = MsContext::GetInstance(); auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
bool is_gpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice); bool is_gpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);

View File

@ -4298,9 +4298,9 @@ class PadV3(Primitive):
ValueError: If `mode` is "constant", the element's number of `paddings` large than input dim * 2. ValueError: If `mode` is "constant", the element's number of `paddings` large than input dim * 2.
ValueError: If `mode` is "edge" or "reflect", the element's number of `paddings` is not 2, 4 or 6. ValueError: If `mode` is "edge" or "reflect", the element's number of `paddings` is not 2, 4 or 6.
ValueError: If `mode` is "edge" or "reflect", `x` dims equals 3, ValueError: If `mode` is "edge" or "reflect", `x` dims equals 3,
the element's number of `paddings` is 2. the element's number of `paddings` is not 2.
ValueError: If `mode` is "edge" or "reflect", `x` dims equals 4, ValueError: If `mode` is "edge" or "reflect", `x` dims equals 4,
the element's number of `paddings` is 4. the element's number of `paddings` is not 4.
ValueError: If `mode` is "edge" or "reflect", `x` dims smaller than 3. ValueError: If `mode` is "edge" or "reflect", `x` dims smaller than 3.
ValueError: If `mode` is "edge", x dims bigger than 5. ValueError: If `mode` is "edge", x dims bigger than 5.
ValueError: If `mode` is "reflect", x dims bigger than 4. ValueError: If `mode` is "reflect", x dims bigger than 4.