forked from mindspore-Ecosystem/mindspore
fix pad ops doc and nulltpr
This commit is contained in:
parent
680f21a547
commit
5788279b3b
|
@ -24,8 +24,8 @@ mindspore.ops.PadV3
|
||||||
- **ValueError** - `mode` 是"constant"的同时 `paddings` 元素个数不是偶数。
|
- **ValueError** - `mode` 是"constant"的同时 `paddings` 元素个数不是偶数。
|
||||||
- **ValueError** - `mode` 是"constant"的同时 `paddings` 元素个数大于输入维度乘以2。
|
- **ValueError** - `mode` 是"constant"的同时 `paddings` 元素个数大于输入维度乘以2。
|
||||||
- **ValueError** - `mode` 是"edge"或"reflect"的同时 `paddings` 元素个数不是2、4或6。
|
- **ValueError** - `mode` 是"edge"或"reflect"的同时 `paddings` 元素个数不是2、4或6。
|
||||||
- **ValueError** - `mode` 是"edge"或"reflect", `x` 的维度是3, `paddings` 元素个数是2。
|
- **ValueError** - `mode` 是"edge"或"reflect", `x` 的维度是3, `paddings` 元素个数不是2。
|
||||||
- **ValueError** - `mode` 是"edge"或"reflect", `x` 的维度是4, `paddings` 元素个数是4。
|
- **ValueError** - `mode` 是"edge"或"reflect", `x` 的维度是4, `paddings` 元素个数不是4。
|
||||||
- **ValueError** - `mode` 是"edge"或"reflect"的同时 `x` 的维度小于3。
|
- **ValueError** - `mode` 是"edge"或"reflect"的同时 `x` 的维度小于3。
|
||||||
- **ValueError** - `mode` 是"edge"的同时 `x` 的维度大于5。
|
- **ValueError** - `mode` 是"edge"的同时 `x` 的维度大于5。
|
||||||
- **ValueError** - `mode` 是"reflect"的同时 `x` 的维度大于4。
|
- **ValueError** - `mode` 是"reflect"的同时 `x` 的维度大于4。
|
||||||
|
|
|
@ -51,6 +51,43 @@ void CheckPaddingParam(const std::vector<int64_t> &paddings_shape, const std::ve
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CheckPaddingValue(const std::vector<std::pair<int64_t, int64_t>> &paddings_attr,
|
||||||
|
const std::vector<int64_t> &x_shape, const std::string &mode, const std::string &prim_name) {
|
||||||
|
int64_t size = static_cast<int64_t>(x_shape.size());
|
||||||
|
if (size < 0 || size > MAX_PADDINGS) {
|
||||||
|
MS_EXCEPTION(ValueError) << "For '" << prim_name
|
||||||
|
<< "', the dimension of input only supports less than or equal to 5 dims, but got " << size
|
||||||
|
<< " dims";
|
||||||
|
}
|
||||||
|
for (int64_t i = 0; i < size; i++) {
|
||||||
|
if (x_shape[i] == abstract::Shape::kShapeDimAny) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (paddings_attr[i].first < 0 || paddings_attr[i].second < 0) {
|
||||||
|
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', all elements of paddings must be >= 0.";
|
||||||
|
}
|
||||||
|
if (mode == "SYMMETRIC") {
|
||||||
|
if (paddings_attr[i].first > static_cast<int64_t>(x_shape[i]) ||
|
||||||
|
paddings_attr[i].second > static_cast<int64_t>(x_shape[i])) {
|
||||||
|
MS_EXCEPTION(ValueError) << "For '" << prim_name
|
||||||
|
<< "', paddings must be no greater "
|
||||||
|
"than the dimension size: ["
|
||||||
|
<< paddings_attr[i].first << "], [" << paddings_attr[i].second << "] greater than ["
|
||||||
|
<< static_cast<int64_t>(x_shape[i]) << "]";
|
||||||
|
}
|
||||||
|
} else if (mode == "REFLECT") {
|
||||||
|
if (paddings_attr[i].first >= static_cast<int64_t>(x_shape[i]) ||
|
||||||
|
paddings_attr[i].second >= static_cast<int64_t>(x_shape[i])) {
|
||||||
|
MS_EXCEPTION(ValueError) << "For '" << prim_name
|
||||||
|
<< "', paddings must be no greater "
|
||||||
|
"than the dimension size: ["
|
||||||
|
<< paddings_attr[i].first << "], [" << paddings_attr[i].second << "] not less than ["
|
||||||
|
<< static_cast<int64_t>(x_shape[i]) << "]";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
MIND_API_OPERATOR_IMPL(MirrorPad, BaseOperator);
|
MIND_API_OPERATOR_IMPL(MirrorPad, BaseOperator);
|
||||||
class MirrorPadInfer : public abstract::OpInferBase {
|
class MirrorPadInfer : public abstract::OpInferBase {
|
||||||
public:
|
public:
|
||||||
|
@ -58,6 +95,9 @@ class MirrorPadInfer : public abstract::OpInferBase {
|
||||||
const std::vector<AbstractBasePtr> &input_args) const override {
|
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
auto prim_name = primitive->name();
|
auto prim_name = primitive->name();
|
||||||
|
for (const auto &item : input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
|
}
|
||||||
auto input_x_shape_ptr = input_args[0]->BuildShape();
|
auto input_x_shape_ptr = input_args[0]->BuildShape();
|
||||||
MS_EXCEPTION_IF_NULL(input_x_shape_ptr);
|
MS_EXCEPTION_IF_NULL(input_x_shape_ptr);
|
||||||
auto input_x_shape = input_x_shape_ptr->cast<abstract::ShapePtr>();
|
auto input_x_shape = input_x_shape_ptr->cast<abstract::ShapePtr>();
|
||||||
|
@ -83,39 +123,7 @@ class MirrorPadInfer : public abstract::OpInferBase {
|
||||||
}
|
}
|
||||||
(void)CheckAndConvertUtils::CheckInteger(kPaddingsSize, SizeToLong(paddings_attr.size()), kEqual,
|
(void)CheckAndConvertUtils::CheckInteger(kPaddingsSize, SizeToLong(paddings_attr.size()), kEqual,
|
||||||
SizeToLong(x_shape.size()), prim_name);
|
SizeToLong(x_shape.size()), prim_name);
|
||||||
int64_t size = static_cast<int64_t>(x_shape.size());
|
CheckPaddingValue(paddings_attr, x_shape, mode, prim_name);
|
||||||
if (size < 0 || size > MAX_PADDINGS) {
|
|
||||||
MS_EXCEPTION(ValueError) << "For '" << prim_name
|
|
||||||
<< "', the dimension of input only supports less than or equal to 5 dims, but got "
|
|
||||||
<< size << " dims";
|
|
||||||
}
|
|
||||||
for (int64_t i = 0; i < size; i++) {
|
|
||||||
if (x_shape[i] == abstract::Shape::kShapeDimAny) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (paddings_attr[i].first < 0 || paddings_attr[i].second < 0) {
|
|
||||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', all elements of paddings must be >= 0.";
|
|
||||||
}
|
|
||||||
if (mode == "SYMMETRIC") {
|
|
||||||
if (paddings_attr[i].first > static_cast<int64_t>(x_shape[i]) ||
|
|
||||||
paddings_attr[i].second > static_cast<int64_t>(x_shape[i])) {
|
|
||||||
MS_EXCEPTION(ValueError) << "For '" << prim_name
|
|
||||||
<< "', paddings must be no greater "
|
|
||||||
"than the dimension size: ["
|
|
||||||
<< paddings_attr[i].first << "], [" << paddings_attr[i].second << "] greater than ["
|
|
||||||
<< static_cast<int64_t>(x_shape[i]) << "]";
|
|
||||||
}
|
|
||||||
} else if (mode == "REFLECT") {
|
|
||||||
if (paddings_attr[i].first >= static_cast<int64_t>(x_shape[i]) ||
|
|
||||||
paddings_attr[i].second >= static_cast<int64_t>(x_shape[i])) {
|
|
||||||
MS_EXCEPTION(ValueError) << "For '" << prim_name
|
|
||||||
<< "', paddings must be no greater "
|
|
||||||
"than the dimension size: ["
|
|
||||||
<< paddings_attr[i].first << "], [" << paddings_attr[i].second << "] not less than ["
|
|
||||||
<< static_cast<int64_t>(x_shape[i]) << "]";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
std::vector<int64_t> out_shape;
|
std::vector<int64_t> out_shape;
|
||||||
for (size_t i = 0; i < x_shape.size(); i++) {
|
for (size_t i = 0; i < x_shape.size(); i++) {
|
||||||
// In dynamic situation , if input axis is dynamic, output axis is dynamic too.
|
// In dynamic situation , if input axis is dynamic, output axis is dynamic too.
|
||||||
|
|
|
@ -31,6 +31,9 @@ namespace {
|
||||||
TypePtr PaddingInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
TypePtr PaddingInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
auto name = primitive->name();
|
auto name = primitive->name();
|
||||||
|
for (const auto &item : input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
|
}
|
||||||
auto context = MsContext::GetInstance();
|
auto context = MsContext::GetInstance();
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
MS_EXCEPTION_IF_NULL(context);
|
||||||
bool is_gpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);
|
bool is_gpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);
|
||||||
|
|
|
@ -4298,9 +4298,9 @@ class PadV3(Primitive):
|
||||||
ValueError: If `mode` is "constant", the element's number of `paddings` large than input dim * 2.
|
ValueError: If `mode` is "constant", the element's number of `paddings` large than input dim * 2.
|
||||||
ValueError: If `mode` is "edge" or "reflect", the element's number of `paddings` is not 2, 4 or 6.
|
ValueError: If `mode` is "edge" or "reflect", the element's number of `paddings` is not 2, 4 or 6.
|
||||||
ValueError: If `mode` is "edge" or "reflect", `x` dims equals 3,
|
ValueError: If `mode` is "edge" or "reflect", `x` dims equals 3,
|
||||||
the element's number of `paddings` is 2.
|
the element's number of `paddings` is not 2.
|
||||||
ValueError: If `mode` is "edge" or "reflect", `x` dims equals 4,
|
ValueError: If `mode` is "edge" or "reflect", `x` dims equals 4,
|
||||||
the element's number of `paddings` is 4.
|
the element's number of `paddings` is not 4.
|
||||||
ValueError: If `mode` is "edge" or "reflect", `x` dims smaller than 3.
|
ValueError: If `mode` is "edge" or "reflect", `x` dims smaller than 3.
|
||||||
ValueError: If `mode` is "edge", x dims bigger than 5.
|
ValueError: If `mode` is "edge", x dims bigger than 5.
|
||||||
ValueError: If `mode` is "reflect", x dims bigger than 4.
|
ValueError: If `mode` is "reflect", x dims bigger than 4.
|
||||||
|
|
Loading…
Reference in New Issue