support dynamic for class _ConstantPadNd(Cell)

This commit is contained in:
ckey_Dou 2023-02-16 12:13:31 +08:00
parent 65f36d8def
commit 029c2a3227
2 changed files with 20 additions and 5 deletions

View File

@ -69,6 +69,22 @@ class MIND_API AGOnesLikeInfer : public abstract::OpInferBase {
const std::vector<AbstractBasePtr> &input_args) const override {
return OnesLikeInfer(engine, primitive, input_args);
}
ValuePtr InferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
if (input_args.empty()) {
return nullptr;
}
auto op_name = primitive->name();
auto shape_ptr = CheckAndConvertUtils::GetTensorInputShape(op_name, input_args, 0);
auto shape_vec = shape_ptr->shape();
if (IsDynamic(shape_vec)) {
return nullptr;
}
auto infer_type = input_args[0]->BuildType();
return TensorConstructUtils::CreateOnesTensor(infer_type, shape_vec);
}
};
REGISTER_PRIMITIVE_OP_INFER_IMPL(OnesLike, prim::kPrimOnesLike, AGOnesLikeInfer, false);

View File

@ -18,7 +18,7 @@ from __future__ import absolute_import
from mindspore.common import Tensor
from mindspore import ops
from mindspore.ops.operations import nn_ops
from mindspore.ops.primitive import constexpr
from mindspore.ops.primitive import constexpr, _primexpr
from mindspore.nn.cell import Cell
__all__ = ['ConstantPad1d', 'ConstantPad2d', 'ConstantPad3d', 'ReflectionPad1d', 'ReflectionPad2d', 'ReflectionPad3d',
@ -83,7 +83,7 @@ def _get_new_padding(padding):
return new_padding, start, end
@constexpr
@_primexpr
def _get_begin_size(shape, begin, end):
"""Calculate begin and size for ops.Slice."""
size = tuple([shape[i] + begin[i] + end[i] for i in range(len(shape))])
@ -166,13 +166,12 @@ class _ConstantPadNd(Cell):
def construct(self, x):
"""Construct the pad net."""
input_shape = x.shape
input_type = x.dtype
padding = _check(input_shape, self.padding)
new_padding, start, end = _get_new_padding(padding)
mask = ops.Ones()(input_shape, input_type)
mask = ops.OnesLike()(x)
output = ops.Pad(new_padding)(x)
mask = ops.Pad(new_padding)(mask)
ones = ops.Ones()(output.shape, output.dtype)
ones = ops.OnesLike()(output)
value = ops.Fill()(output.dtype, output.shape, self.value)
output = ops.Add()(ops.Mul()(mask, output), ops.Mul()(ops.Sub()(ones, mask), value))
slice_op = ops.Slice()