forked from mindspore-Ecosystem/mindspore
pynative-support-reducemean
This commit is contained in:
parent
57f6fa6439
commit
c6b2b0df1e
|
@ -21,6 +21,7 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <set>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
|
|
||||||
|
@ -59,7 +60,7 @@ struct OpExecInfo {
|
||||||
using OpExecInfoPtr = std::shared_ptr<OpExecInfo>;
|
using OpExecInfoPtr = std::shared_ptr<OpExecInfo>;
|
||||||
OpExecInfoPtr GenerateOpExecInfo(const py::args& args);
|
OpExecInfoPtr GenerateOpExecInfo(const py::args& args);
|
||||||
|
|
||||||
const std::unordered_set<std::string> ignore_infer_prim = {"partial"};
|
const std::set<std::string> ignore_infer_prim = {"partial", "make_ref"};
|
||||||
|
|
||||||
} // namespace pynative
|
} // namespace pynative
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -24,7 +24,7 @@ from ..._checkparam import Rel
|
||||||
from ...common import dtype as mstype
|
from ...common import dtype as mstype
|
||||||
from ...common.tensor import Tensor
|
from ...common.tensor import Tensor
|
||||||
from .._utils import _get_broadcast_shape
|
from .._utils import _get_broadcast_shape
|
||||||
from ..primitive import PrimitiveWithInfer, prim_attr_register
|
from ..primitive import PrimitiveWithInfer, prim_attr_register, _run_op
|
||||||
|
|
||||||
|
|
||||||
def _infer_shape_reduce(x, axis, keep_dims, prim_name):
|
def _infer_shape_reduce(x, axis, keep_dims, prim_name):
|
||||||
|
@ -225,6 +225,11 @@ class _Reduce(PrimitiveWithInfer):
|
||||||
validator.check_value_type('keep_dims', keep_dims, [bool], self.name)
|
validator.check_value_type('keep_dims', keep_dims, [bool], self.name)
|
||||||
self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y'])
|
self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y'])
|
||||||
|
|
||||||
|
def __call__(self, x, axis=()):
|
||||||
|
args = [x, axis]
|
||||||
|
output = _run_op(self, self.name, args)
|
||||||
|
return output
|
||||||
|
|
||||||
def do_infer(self, input_x, axis, valid_dtype=mstype.number_type):
|
def do_infer(self, input_x, axis, valid_dtype=mstype.number_type):
|
||||||
axis_v = axis['value']
|
axis_v = axis['value']
|
||||||
input_shp = input_x['shape']
|
input_shp = input_x['shape']
|
||||||
|
|
Loading…
Reference in New Issue