forked from mindspore-Ecosystem/mindspore
fix pad ops doc and nulltpr
This commit is contained in:
parent
680f21a547
commit
5788279b3b
|
@ -24,8 +24,8 @@ mindspore.ops.PadV3
|
||||||
- **ValueError** - `mode` 是"constant"的同时 `paddings` 元素个数不是偶数。
|
- **ValueError** - `mode` 是"constant"的同时 `paddings` 元素个数不是偶数。
|
||||||
- **ValueError** - `mode` 是"constant"的同时 `paddings` 元素个数大于输入维度乘以2。
|
- **ValueError** - `mode` 是"constant"的同时 `paddings` 元素个数大于输入维度乘以2。
|
||||||
- **ValueError** - `mode` 是"edge"或"reflect"的同时 `paddings` 元素个数不是2、4或6。
|
- **ValueError** - `mode` 是"edge"或"reflect"的同时 `paddings` 元素个数不是2、4或6。
|
||||||
- **ValueError** - `mode` 是"edge"或"reflect", `x` 的维度是3, `paddings` 元素个数是2。
|
- **ValueError** - `mode` 是"edge"或"reflect", `x` 的维度是3, `paddings` 元素个数不是2。
|
||||||
- **ValueError** - `mode` 是"edge"或"reflect", `x` 的维度是4, `paddings` 元素个数是4。
|
- **ValueError** - `mode` 是"edge"或"reflect", `x` 的维度是4, `paddings` 元素个数不是4。
|
||||||
- **ValueError** - `mode` 是"edge"或"reflect"的同时 `x` 的维度小于3。
|
- **ValueError** - `mode` 是"edge"或"reflect"的同时 `x` 的维度小于3。
|
||||||
- **ValueError** - `mode` 是"edge"的同时 `x` 的维度大于5。
|
- **ValueError** - `mode` 是"edge"的同时 `x` 的维度大于5。
|
||||||
- **ValueError** - `mode` 是"reflect"的同时 `x` 的维度大于4。
|
- **ValueError** - `mode` 是"reflect"的同时 `x` 的维度大于4。
|
||||||
|
|
|
@ -51,43 +51,13 @@ void CheckPaddingParam(const std::vector<int64_t> &paddings_shape, const std::ve
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
MIND_API_OPERATOR_IMPL(MirrorPad, BaseOperator);
|
void CheckPaddingValue(const std::vector<std::pair<int64_t, int64_t>> &paddings_attr,
|
||||||
class MirrorPadInfer : public abstract::OpInferBase {
|
const std::vector<int64_t> &x_shape, const std::string &mode, const std::string &prim_name) {
|
||||||
public:
|
|
||||||
BaseShapePtr InferShape(const PrimitivePtr &primitive,
|
|
||||||
const std::vector<AbstractBasePtr> &input_args) const override {
|
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
|
||||||
auto prim_name = primitive->name();
|
|
||||||
auto input_x_shape_ptr = input_args[0]->BuildShape();
|
|
||||||
MS_EXCEPTION_IF_NULL(input_x_shape_ptr);
|
|
||||||
auto input_x_shape = input_x_shape_ptr->cast<abstract::ShapePtr>();
|
|
||||||
// Dynamic rank process.
|
|
||||||
if (IsDynamicRank(input_x_shape->shape())) {
|
|
||||||
return std::make_shared<abstract::Shape>(ShapeVector{abstract::Shape::kShapeRankAny});
|
|
||||||
}
|
|
||||||
auto paddings = input_args[1]->BuildValue();
|
|
||||||
MS_EXCEPTION_IF_NULL(paddings);
|
|
||||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
|
||||||
// if shape of x is determined and padding value is unknown, return a all -1 shape
|
|
||||||
if (paddings->isa<AnyValue>() || paddings->isa<None>()) {
|
|
||||||
return std::make_shared<abstract::Shape>(ShapeVector(x_shape.size(), abstract::Shape::kShapeDimAny));
|
|
||||||
}
|
|
||||||
auto paddings_arg = CheckAndConvertUtils::CheckTensorIntValue(kPaddings, paddings, prim_name);
|
|
||||||
std::vector<std::pair<int64_t, int64_t>> paddings_attr;
|
|
||||||
|
|
||||||
auto paddings_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
|
||||||
auto mode = GetValue<std::string>(primitive->GetAttr(kMode));
|
|
||||||
CheckPaddingParam(paddings_shape, x_shape, prim_name);
|
|
||||||
for (size_t i = 0; i < paddings_arg.size(); i = i + static_cast<size_t>(kPaddingsSecondDimSize)) {
|
|
||||||
paddings_attr.push_back(std::make_pair(paddings_arg[i], paddings_arg[i + 1]));
|
|
||||||
}
|
|
||||||
(void)CheckAndConvertUtils::CheckInteger(kPaddingsSize, SizeToLong(paddings_attr.size()), kEqual,
|
|
||||||
SizeToLong(x_shape.size()), prim_name);
|
|
||||||
int64_t size = static_cast<int64_t>(x_shape.size());
|
int64_t size = static_cast<int64_t>(x_shape.size());
|
||||||
if (size < 0 || size > MAX_PADDINGS) {
|
if (size < 0 || size > MAX_PADDINGS) {
|
||||||
MS_EXCEPTION(ValueError) << "For '" << prim_name
|
MS_EXCEPTION(ValueError) << "For '" << prim_name
|
||||||
<< "', the dimension of input only supports less than or equal to 5 dims, but got "
|
<< "', the dimension of input only supports less than or equal to 5 dims, but got " << size
|
||||||
<< size << " dims";
|
<< " dims";
|
||||||
}
|
}
|
||||||
for (int64_t i = 0; i < size; i++) {
|
for (int64_t i = 0; i < size; i++) {
|
||||||
if (x_shape[i] == abstract::Shape::kShapeDimAny) {
|
if (x_shape[i] == abstract::Shape::kShapeDimAny) {
|
||||||
|
@ -116,6 +86,44 @@ class MirrorPadInfer : public abstract::OpInferBase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
MIND_API_OPERATOR_IMPL(MirrorPad, BaseOperator);
|
||||||
|
class MirrorPadInfer : public abstract::OpInferBase {
|
||||||
|
public:
|
||||||
|
BaseShapePtr InferShape(const PrimitivePtr &primitive,
|
||||||
|
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto prim_name = primitive->name();
|
||||||
|
for (const auto &item : input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
|
}
|
||||||
|
auto input_x_shape_ptr = input_args[0]->BuildShape();
|
||||||
|
MS_EXCEPTION_IF_NULL(input_x_shape_ptr);
|
||||||
|
auto input_x_shape = input_x_shape_ptr->cast<abstract::ShapePtr>();
|
||||||
|
// Dynamic rank process.
|
||||||
|
if (IsDynamicRank(input_x_shape->shape())) {
|
||||||
|
return std::make_shared<abstract::Shape>(ShapeVector{abstract::Shape::kShapeRankAny});
|
||||||
|
}
|
||||||
|
auto paddings = input_args[1]->BuildValue();
|
||||||
|
MS_EXCEPTION_IF_NULL(paddings);
|
||||||
|
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||||
|
// if shape of x is determined and padding value is unknown, return a all -1 shape
|
||||||
|
if (paddings->isa<AnyValue>() || paddings->isa<None>()) {
|
||||||
|
return std::make_shared<abstract::Shape>(ShapeVector(x_shape.size(), abstract::Shape::kShapeDimAny));
|
||||||
|
}
|
||||||
|
auto paddings_arg = CheckAndConvertUtils::CheckTensorIntValue(kPaddings, paddings, prim_name);
|
||||||
|
std::vector<std::pair<int64_t, int64_t>> paddings_attr;
|
||||||
|
|
||||||
|
auto paddings_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||||
|
auto mode = GetValue<std::string>(primitive->GetAttr(kMode));
|
||||||
|
CheckPaddingParam(paddings_shape, x_shape, prim_name);
|
||||||
|
for (size_t i = 0; i < paddings_arg.size(); i = i + static_cast<size_t>(kPaddingsSecondDimSize)) {
|
||||||
|
paddings_attr.push_back(std::make_pair(paddings_arg[i], paddings_arg[i + 1]));
|
||||||
|
}
|
||||||
|
(void)CheckAndConvertUtils::CheckInteger(kPaddingsSize, SizeToLong(paddings_attr.size()), kEqual,
|
||||||
|
SizeToLong(x_shape.size()), prim_name);
|
||||||
|
CheckPaddingValue(paddings_attr, x_shape, mode, prim_name);
|
||||||
std::vector<int64_t> out_shape;
|
std::vector<int64_t> out_shape;
|
||||||
for (size_t i = 0; i < x_shape.size(); i++) {
|
for (size_t i = 0; i < x_shape.size(); i++) {
|
||||||
// In dynamic situation , if input axis is dynamic, output axis is dynamic too.
|
// In dynamic situation , if input axis is dynamic, output axis is dynamic too.
|
||||||
|
|
|
@ -31,6 +31,9 @@ namespace {
|
||||||
TypePtr PaddingInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
TypePtr PaddingInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
auto name = primitive->name();
|
auto name = primitive->name();
|
||||||
|
for (const auto &item : input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
|
}
|
||||||
auto context = MsContext::GetInstance();
|
auto context = MsContext::GetInstance();
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
MS_EXCEPTION_IF_NULL(context);
|
||||||
bool is_gpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);
|
bool is_gpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);
|
||||||
|
|
|
@ -4298,9 +4298,9 @@ class PadV3(Primitive):
|
||||||
ValueError: If `mode` is "constant", the element's number of `paddings` large than input dim * 2.
|
ValueError: If `mode` is "constant", the element's number of `paddings` large than input dim * 2.
|
||||||
ValueError: If `mode` is "edge" or "reflect", the element's number of `paddings` is not 2, 4 or 6.
|
ValueError: If `mode` is "edge" or "reflect", the element's number of `paddings` is not 2, 4 or 6.
|
||||||
ValueError: If `mode` is "edge" or "reflect", `x` dims equals 3,
|
ValueError: If `mode` is "edge" or "reflect", `x` dims equals 3,
|
||||||
the element's number of `paddings` is 2.
|
the element's number of `paddings` is not 2.
|
||||||
ValueError: If `mode` is "edge" or "reflect", `x` dims equals 4,
|
ValueError: If `mode` is "edge" or "reflect", `x` dims equals 4,
|
||||||
the element's number of `paddings` is 4.
|
the element's number of `paddings` is not 4.
|
||||||
ValueError: If `mode` is "edge" or "reflect", `x` dims smaller than 3.
|
ValueError: If `mode` is "edge" or "reflect", `x` dims smaller than 3.
|
||||||
ValueError: If `mode` is "edge", x dims bigger than 5.
|
ValueError: If `mode` is "edge", x dims bigger than 5.
|
||||||
ValueError: If `mode` is "reflect", x dims bigger than 4.
|
ValueError: If `mode` is "reflect", x dims bigger than 4.
|
||||||
|
|
Loading…
Reference in New Issue