forked from mindspore-Ecosystem/mindspore
support dynamic for class _ConstantPadNd(Cell)
This commit is contained in:
parent
65f36d8def
commit
029c2a3227
|
@ -69,6 +69,22 @@ class MIND_API AGOnesLikeInfer : public abstract::OpInferBase {
|
|||
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
return OnesLikeInfer(engine, primitive, input_args);
|
||||
}
|
||||
|
||||
ValuePtr InferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
if (input_args.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto op_name = primitive->name();
|
||||
auto shape_ptr = CheckAndConvertUtils::GetTensorInputShape(op_name, input_args, 0);
|
||||
auto shape_vec = shape_ptr->shape();
|
||||
if (IsDynamic(shape_vec)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto infer_type = input_args[0]->BuildType();
|
||||
return TensorConstructUtils::CreateOnesTensor(infer_type, shape_vec);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(OnesLike, prim::kPrimOnesLike, AGOnesLikeInfer, false);
|
||||
|
|
|
@ -18,7 +18,7 @@ from __future__ import absolute_import
|
|||
from mindspore.common import Tensor
|
||||
from mindspore import ops
|
||||
from mindspore.ops.operations import nn_ops
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.ops.primitive import constexpr, _primexpr
|
||||
from mindspore.nn.cell import Cell
|
||||
|
||||
__all__ = ['ConstantPad1d', 'ConstantPad2d', 'ConstantPad3d', 'ReflectionPad1d', 'ReflectionPad2d', 'ReflectionPad3d',
|
||||
|
@ -83,7 +83,7 @@ def _get_new_padding(padding):
|
|||
return new_padding, start, end
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _get_begin_size(shape, begin, end):
|
||||
"""Calculate begin and size for ops.Slice."""
|
||||
size = tuple([shape[i] + begin[i] + end[i] for i in range(len(shape))])
|
||||
|
@ -166,13 +166,12 @@ class _ConstantPadNd(Cell):
|
|||
def construct(self, x):
|
||||
"""Construct the pad net."""
|
||||
input_shape = x.shape
|
||||
input_type = x.dtype
|
||||
padding = _check(input_shape, self.padding)
|
||||
new_padding, start, end = _get_new_padding(padding)
|
||||
mask = ops.Ones()(input_shape, input_type)
|
||||
mask = ops.OnesLike()(x)
|
||||
output = ops.Pad(new_padding)(x)
|
||||
mask = ops.Pad(new_padding)(mask)
|
||||
ones = ops.Ones()(output.shape, output.dtype)
|
||||
ones = ops.OnesLike()(output)
|
||||
value = ops.Fill()(output.dtype, output.shape, self.value)
|
||||
output = ops.Add()(ops.Mul()(mask, output), ops.Mul()(ops.Sub()(ones, mask), value))
|
||||
slice_op = ops.Slice()
|
||||
|
|
Loading…
Reference in New Issue