fix pad ops doc and nullptr

fangzehua 2023-01-18 11:22:30 +08:00
parent 680f21a547
commit 5788279b3b
4 changed files with 48 additions and 37 deletions


@@ -24,8 +24,8 @@ mindspore.ops.PadV3
 - **ValueError** - `mode` is "constant" and the number of elements of `paddings` is not even.
 - **ValueError** - `mode` is "constant" and the number of elements of `paddings` is greater than the input dimension times 2.
 - **ValueError** - `mode` is "edge" or "reflect" and the number of elements of `paddings` is not 2, 4 or 6.
-- **ValueError** - `mode` is "edge" or "reflect", the dimension of `x` is 3, and the number of elements of `paddings` is 2.
-- **ValueError** - `mode` is "edge" or "reflect", the dimension of `x` is 4, and the number of elements of `paddings` is 4.
+- **ValueError** - `mode` is "edge" or "reflect", the dimension of `x` is 3, and the number of elements of `paddings` is not 2.
+- **ValueError** - `mode` is "edge" or "reflect", the dimension of `x` is 4, and the number of elements of `paddings` is not 4.
 - **ValueError** - `mode` is "edge" or "reflect" and the dimension of `x` is less than 3.
 - **ValueError** - `mode` is "edge" and the dimension of `x` is greater than 5.
 - **ValueError** - `mode` is "reflect" and the dimension of `x` is greater than 4.
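To make the corrected rule concrete: with a 3-D input in "reflect" (or "edge") mode, `paddings` must hold exactly 2 elements, which pad the last axis. The sketch below is illustrative only (not part of this commit) and assumes the usual `ops.PadV3(mode=...)` construction with an `(x, paddings)` call; the exact call signature may differ across MindSpore versions.

```python
import numpy as np
from mindspore import Tensor, ops

# 3-D input, so "reflect" mode expects exactly 2 padding elements (last axis only).
x = Tensor(np.arange(6, dtype=np.float32).reshape(1, 1, 6))
paddings = Tensor(np.array([1, 2], dtype=np.int64))  # (pad_left, pad_right)

pad = ops.PadV3(mode="reflect")
y = pad(x, paddings)
print(y.shape)  # expected (1, 1, 9): 6 + 1 + 2
# A 4-element paddings tensor here would trigger the ValueError described above.
```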


@@ -51,43 +51,13 @@ void CheckPaddingParam(const std::vector<int64_t> &paddings_shape, const std::ve
     return;
   }
-MIND_API_OPERATOR_IMPL(MirrorPad, BaseOperator);
-class MirrorPadInfer : public abstract::OpInferBase {
- public:
-  BaseShapePtr InferShape(const PrimitivePtr &primitive,
-                          const std::vector<AbstractBasePtr> &input_args) const override {
-    MS_EXCEPTION_IF_NULL(primitive);
-    auto prim_name = primitive->name();
-    auto input_x_shape_ptr = input_args[0]->BuildShape();
-    MS_EXCEPTION_IF_NULL(input_x_shape_ptr);
-    auto input_x_shape = input_x_shape_ptr->cast<abstract::ShapePtr>();
-    // Dynamic rank process.
-    if (IsDynamicRank(input_x_shape->shape())) {
-      return std::make_shared<abstract::Shape>(ShapeVector{abstract::Shape::kShapeRankAny});
-    }
-    auto paddings = input_args[1]->BuildValue();
-    MS_EXCEPTION_IF_NULL(paddings);
-    auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
-    // if shape of x is determined and padding value is unknown, return a all -1 shape
-    if (paddings->isa<AnyValue>() || paddings->isa<None>()) {
-      return std::make_shared<abstract::Shape>(ShapeVector(x_shape.size(), abstract::Shape::kShapeDimAny));
-    }
-    auto paddings_arg = CheckAndConvertUtils::CheckTensorIntValue(kPaddings, paddings, prim_name);
-    std::vector<std::pair<int64_t, int64_t>> paddings_attr;
-    auto paddings_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
-    auto mode = GetValue<std::string>(primitive->GetAttr(kMode));
-    CheckPaddingParam(paddings_shape, x_shape, prim_name);
-    for (size_t i = 0; i < paddings_arg.size(); i = i + static_cast<size_t>(kPaddingsSecondDimSize)) {
-      paddings_attr.push_back(std::make_pair(paddings_arg[i], paddings_arg[i + 1]));
-    }
-    (void)CheckAndConvertUtils::CheckInteger(kPaddingsSize, SizeToLong(paddings_attr.size()), kEqual,
-                                             SizeToLong(x_shape.size()), prim_name);
-    int64_t size = static_cast<int64_t>(x_shape.size());
-    if (size < 0 || size > MAX_PADDINGS) {
-      MS_EXCEPTION(ValueError) << "For '" << prim_name
-                               << "', the dimension of input only supports less than or equal to 5 dims, but got "
-                               << size << " dims";
-    }
-    for (int64_t i = 0; i < size; i++) {
-      if (x_shape[i] == abstract::Shape::kShapeDimAny) {
+void CheckPaddingValue(const std::vector<std::pair<int64_t, int64_t>> &paddings_attr,
+                       const std::vector<int64_t> &x_shape, const std::string &mode, const std::string &prim_name) {
+  int64_t size = static_cast<int64_t>(x_shape.size());
+  if (size < 0 || size > MAX_PADDINGS) {
+    MS_EXCEPTION(ValueError) << "For '" << prim_name
+                             << "', the dimension of input only supports less than or equal to 5 dims, but got " << size
+                             << " dims";
+  }
+  for (int64_t i = 0; i < size; i++) {
+    if (x_shape[i] == abstract::Shape::kShapeDimAny) {
@@ -116,6 +86,44 @@ class MirrorPadInfer : public abstract::OpInferBase {
       }
     }
   }
+}
+MIND_API_OPERATOR_IMPL(MirrorPad, BaseOperator);
+class MirrorPadInfer : public abstract::OpInferBase {
+ public:
+  BaseShapePtr InferShape(const PrimitivePtr &primitive,
+                          const std::vector<AbstractBasePtr> &input_args) const override {
+    MS_EXCEPTION_IF_NULL(primitive);
+    auto prim_name = primitive->name();
+    for (const auto &item : input_args) {
+      MS_EXCEPTION_IF_NULL(item);
+    }
+    auto input_x_shape_ptr = input_args[0]->BuildShape();
+    MS_EXCEPTION_IF_NULL(input_x_shape_ptr);
+    auto input_x_shape = input_x_shape_ptr->cast<abstract::ShapePtr>();
+    // Dynamic rank process.
+    if (IsDynamicRank(input_x_shape->shape())) {
+      return std::make_shared<abstract::Shape>(ShapeVector{abstract::Shape::kShapeRankAny});
+    }
+    auto paddings = input_args[1]->BuildValue();
+    MS_EXCEPTION_IF_NULL(paddings);
+    auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
+    // if shape of x is determined and padding value is unknown, return a all -1 shape
+    if (paddings->isa<AnyValue>() || paddings->isa<None>()) {
+      return std::make_shared<abstract::Shape>(ShapeVector(x_shape.size(), abstract::Shape::kShapeDimAny));
+    }
+    auto paddings_arg = CheckAndConvertUtils::CheckTensorIntValue(kPaddings, paddings, prim_name);
+    std::vector<std::pair<int64_t, int64_t>> paddings_attr;
+    auto paddings_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
+    auto mode = GetValue<std::string>(primitive->GetAttr(kMode));
+    CheckPaddingParam(paddings_shape, x_shape, prim_name);
+    for (size_t i = 0; i < paddings_arg.size(); i = i + static_cast<size_t>(kPaddingsSecondDimSize)) {
+      paddings_attr.push_back(std::make_pair(paddings_arg[i], paddings_arg[i + 1]));
+    }
+    (void)CheckAndConvertUtils::CheckInteger(kPaddingsSize, SizeToLong(paddings_attr.size()), kEqual,
+                                             SizeToLong(x_shape.size()), prim_name);
+    CheckPaddingValue(paddings_attr, x_shape, mode, prim_name);
     std::vector<int64_t> out_shape;
     for (size_t i = 0; i < x_shape.size(); i++) {
       // In dynamic situation , if input axis is dynamic, output axis is dynamic too.
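For orientation, the shape rule that the unchanged tail of `InferShape` applies is simple: `paddings` is consumed as `(before, after)` pairs, one pair per input dimension, and a dynamic input axis stays dynamic in the output. A toy Python re-statement, purely for illustration (the function name and the `-1` sentinel are stand-ins for the C++ code and `abstract::Shape::kShapeDimAny`):

```python
DIM_ANY = -1  # stand-in for abstract::Shape::kShapeDimAny

def mirror_pad_out_shape(x_shape, paddings):
    """Toy model of MirrorPad shape inference: one (before, after) pair per axis."""
    pairs = list(zip(paddings[0::2], paddings[1::2]))
    if len(pairs) != len(x_shape):
        raise ValueError(f"expected {len(x_shape)} padding pairs, got {len(pairs)}")
    out = []
    for dim, (before, after) in zip(x_shape, pairs):
        # A dynamic input axis makes the corresponding output axis dynamic too.
        out.append(DIM_ANY if dim == DIM_ANY else dim + before + after)
    return out

print(mirror_pad_out_shape([2, 3], [1, 1, 2, 2]))   # [4, 7]
print(mirror_pad_out_shape([-1, 3], [1, 1, 2, 2]))  # [-1, 7]
```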


@@ -31,6 +31,9 @@ namespace {
 TypePtr PaddingInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto name = primitive->name();
+  for (const auto &item : input_args) {
+    MS_EXCEPTION_IF_NULL(item);
+  }
   auto context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context);
   bool is_gpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);
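The three added lines follow the same pattern as the MirrorPad change above: every element of `input_args` is null-checked up front, before later code dereferences it (for example via `BuildShape`/`BuildValue`). A rough Python analogue of that defensive guard, for illustration only (the function name is a hypothetical stand-in):

```python
def padding_infer_type(primitive, input_args):
    # Fail fast with a clear error instead of dereferencing a missing input later.
    if primitive is None:
        raise ValueError("primitive must not be None")
    for item in input_args:
        if item is None:
            raise ValueError("input_args must not contain None")
    # ... the actual type inference would continue here ...
```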


@@ -4298,9 +4298,9 @@ class PadV3(Primitive):
         ValueError: If `mode` is "constant", the element's number of `paddings` large than input dim * 2.
         ValueError: If `mode` is "edge" or "reflect", the element's number of `paddings` is not 2, 4 or 6.
         ValueError: If `mode` is "edge" or "reflect", `x` dims equals 3,
-            the element's number of `paddings` is 2.
+            the element's number of `paddings` is not 2.
         ValueError: If `mode` is "edge" or "reflect", `x` dims equals 4,
-            the element's number of `paddings` is 4.
+            the element's number of `paddings` is not 4.
         ValueError: If `mode` is "edge" or "reflect", `x` dims smaller than 3.
         ValueError: If `mode` is "edge", x dims bigger than 5.
         ValueError: If `mode` is "reflect", x dims bigger than 4.