fix top type

reku1997 2022-08-16 09:41:19 +08:00
parent dfa2cd50ba
commit 73b7a3de36
4 changed files with 42 additions and 5 deletions


@@ -172,6 +172,7 @@ PYBIND11_MODULE(_c_expression, m) {
   (void)m.def("init_cluster", &mindspore::distributed::Initialize, "Init Cluster");
   (void)m.def("get_dyn_shape", &mindspore::pynative::GetDynShape, "Get Dynamic Shape of Tensor");
+  (void)m.def("call_constant_folding", &mindspore::pynative::CallConstantFolding, "Call Constant Folding Primitive");
   (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
     .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")


@@ -18,6 +18,7 @@
 #include "pipeline/jit/debug/trace.h"
 #include "pybind_api/pybind_patch.h"
 #include "include/common/utils/config_manager.h"
+#include "include/common/utils/convert_utils_py.h"
 #include "include/common/pybind_api/api_register.h"
 #include "frontend/optimizer/ad/grad.h"
 #include "pipeline/jit/pass.h"
@@ -92,6 +93,31 @@ py::object GetDynShape(const py::args &args) {
   return executor->forward_executor()->dynamic_shape()->GetDynShape(args);
 }
 
+py::object CallConstantFolding(const py::args &args) {
+  const auto &prim_arg = args[0];
+  const auto &adapter = py::cast<PrimitivePyAdapterPtr>(prim_arg);
+  MS_EXCEPTION_IF_NULL(adapter);
+  auto prim = adapter->attached_primitive();
+  if (prim == nullptr) {
+    prim = std::make_shared<PrimitivePy>(prim_arg, adapter);
+    adapter->set_attached_primitive(prim);
+  }
+  if (!prim->HasPyObj()) {
+    MS_LOG(EXCEPTION) << "Pyobj is empty";
+  }
+  const auto &v = PyNativeAlgo::DataConvert::PyObjToValue(args[1]);
+  std::vector<AbstractBasePtr> input_abs;
+  input_abs.push_back(v->ToAbstract());
+  prim->BeginRecordAddAttr();
+  auto eval_ret = EvalOnePrim(prim, input_abs);
+  MS_EXCEPTION_IF_NULL(eval_ret);
+  AbstractBasePtr infer_res = eval_ret->abstract();
+  MS_EXCEPTION_IF_NULL(infer_res);
+  prim->EndRecordAddAttr();
+  auto value_ptr = PyNativeAlgo::DataConvert::PyObjToValue(ConvertAbstractToPython(infer_res)[ATTR_VALUE]);
+  return ValueToPyData(value_ptr);
+}
+
 void PyNativeExecutor::set_py_exe_path(const py::object &py_exe_path) const {
   if (!py::isinstance<py::str>(py_exe_path)) {
     MS_LOG(EXCEPTION) << "Failed, py_exe_path input is not a str";


@@ -19,6 +19,7 @@
 #include <memory>
 #include <string>
 #include <vector>
+#include "pipeline/pynative/forward/forward.h"
 #include "pipeline/pynative/grad/grad.h"
 #include "pybind11/pybind11.h"
@@ -34,6 +35,7 @@ using GradExecutorPtr = std::shared_ptr<GradExecutor>;
 py::object RealRunOp(const py::args &args);
 py::object GetDynShape(const py::args &args);
+py::object CallConstantFolding(const py::args &args);
 
 class PyNativeExecutor : public std::enable_shared_from_this<PyNativeExecutor> {
  public:


@@ -29,6 +29,7 @@ from ..._checkparam import Validator as validator
 from ...common import dtype as mstype
 from ...common.parameter import Parameter
 from ...communication.management import GlobalComm
+from ..._c_expression import call_constant_folding
 
 
 class FillV2(Primitive):
@@ -2094,18 +2095,25 @@ class TopTypeof(Primitive):
     @prim_attr_register
     def __init__(self):
-        self.typeof_cache = dict()
-        self.prim_toptypeof = Primitive('TopTypeof')
+        self.prim = Primitive('TopTypeof')
+        self.typeof_cache = {
+            'slice': call_constant_folding(self.prim, slice(None, None, None)),
+            'list': call_constant_folding(self.prim, []),
+            'tuple': call_constant_folding(self.prim, ()),
+            'Tensor': call_constant_folding(self.prim, Tensor(np.ones([1,], dtype=np.float32))),
+            'NoneType': call_constant_folding(self.prim, None),
+            'int': call_constant_folding(self.prim, 0),
+            'bool': call_constant_folding(self.prim, False),
+            'ellipsis': call_constant_folding(self.prim, ...)
+        }
 
     def __call__(self, x):
         index_type = type(x).__name__
         if 'Tensor' in index_type:
             index_type = 'Tensor'
         if index_type in ('slice', 'list', 'tuple', 'Tensor', 'NoneType', 'int', 'bool', 'ellipsis'):
-            if index_type not in self.typeof_cache:
-                self.typeof_cache[index_type] = self.prim_toptypeof(x)
             return self.typeof_cache.get(index_type)
-        return self.prim_toptypeof(x)
+        return call_constant_folding(self.prim, x)
 
 
 class MixedPrecisionCast(Primitive):
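
With the eight common index types folded once in __init__, __call__ reduces to a dict lookup for those types, and only an unlisted type pays for a fresh call_constant_folding; the per-call prim_toptypeof dispatch is gone entirely. A short usage sketch, assuming an installed build; the import path below is hypothetical and depends on where this file exports TopTypeof:

import numpy as np
from mindspore import Tensor
from mindspore.ops.operations import TopTypeof  # hypothetical import path

top_typeof = TopTypeof()
t_int = top_typeof(0)                      # hit: typeof_cache['int']
t_tensor = top_typeof(Tensor(np.ones(1)))  # hit: typeof_cache['Tensor']
t_float = top_typeof(1.5)                  # miss: folded on demand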