forked from mindspore-Ecosystem/mindspore
Fix ms function bug
Signed-off-by: zjun <zhangjun0@huawei.com>
This commit is contained in:
parent
d021c11ddc
commit
f3b4e35146
|
@ -30,16 +30,21 @@ namespace mindspore {
|
|||
namespace pynative {
|
||||
namespace {
|
||||
const char kAddedValue[] = "added_value";
|
||||
const mindspore::HashSet<std::string> kNotRealOP{prim::kPrimMakeTuple->name(),
|
||||
prim::kPrimTupleGetItem->name(),
|
||||
prim::kPrimStopGradient->name(),
|
||||
prim::kPrimUpdateState->name(),
|
||||
prim::kPrimLoad->name(),
|
||||
prim::kPrimDepend->name(),
|
||||
prim::kPrimReturn->name(),
|
||||
prim::kPrimNPUAllocFloatStatus->name(),
|
||||
prim::kPrimNPUGetFloatStatus->name(),
|
||||
prim::kPrimNPUClearFloatStatus->name()};
|
||||
const mindspore::HashSet<std::string> kNotRealOP{
|
||||
prim::kPrimMakeTuple->name(),
|
||||
prim::kPrimTupleGetItem->name(),
|
||||
prim::kPrimStopGradient->name(),
|
||||
prim::kPrimUpdateState->name(),
|
||||
prim::kPrimLoad->name(),
|
||||
prim::kPrimDepend->name(),
|
||||
prim::kPrimReturn->name(),
|
||||
prim::kPrimNPUAllocFloatStatus->name(),
|
||||
prim::kPrimNPUGetFloatStatus->name(),
|
||||
prim::kPrimNPUClearFloatStatus->name(),
|
||||
prim::kPrimMirror->name(),
|
||||
prim::kPrimPyExecute->name(),
|
||||
prim::kPrimPyInterpret->name(),
|
||||
};
|
||||
|
||||
FrontendOpRunInfoPtr GetOpRunInfo(const py::object &out, const py::args &args, const std::string &graph_phase,
|
||||
ValuePtr *added_out_v) {
|
||||
|
@ -306,7 +311,9 @@ void MsFunction::MakeCNodeForMsFunction(const FrontendOpRunInfoPtr &op_run_info,
|
|||
// Make a CNode which includes ms_function fprop graph and inputs node
|
||||
MS_EXCEPTION_IF_NULL(ms_function_cnode);
|
||||
*ms_function_cnode = grad_executor->top_cell()->fg()->NewCNode(input_nodes);
|
||||
(*ms_function_cnode)->set_abstract(ms_func_graph->output()->abstract());
|
||||
// If ms function is dynamic shape, used actual shape in pynative mode
|
||||
(*ms_function_cnode)
|
||||
->set_abstract(PyNativeAlgo::Common::SetAbstractValueToAnyValue(op_run_info->out_value->ToAbstract()));
|
||||
MS_LOG(DEBUG) << "Make ms function forward CNode: " << (*ms_function_cnode)->DebugString();
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue