forked from OSSInnovation/mindspore
!499 pynative support topk and print op
Merge pull request !499 from JoyLvliang/pynative-support-topk-and-print
This commit is contained in:
commit
475f62f680
|
@ -135,10 +135,11 @@ void KernelRuntime::AssignMemory(session::KernelGraph *graph) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors,
|
void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors,
|
||||||
const session::KernelGraph *graph) {
|
session::KernelGraph *graph) {
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
// assign memory for input nodes
|
// assign memory for input nodes
|
||||||
RunOpAssignInputMemory(input_tensors, graph);
|
RunOpAssignInputMemory(input_tensors, graph);
|
||||||
|
AssignStaticMemoryValueNode(graph);
|
||||||
for (const auto &cnode : graph->execution_order()) {
|
for (const auto &cnode : graph->execution_order()) {
|
||||||
// assign memory for output nodes
|
// assign memory for output nodes
|
||||||
RunOpAssignOutputMemory(cnode);
|
RunOpAssignOutputMemory(cnode);
|
||||||
|
|
|
@ -46,7 +46,7 @@ class KernelRuntime {
|
||||||
virtual ~KernelRuntime();
|
virtual ~KernelRuntime();
|
||||||
virtual bool Init() = 0;
|
virtual bool Init() = 0;
|
||||||
virtual void AssignMemory(session::KernelGraph *graph);
|
virtual void AssignMemory(session::KernelGraph *graph);
|
||||||
void RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, const session::KernelGraph *graph);
|
void RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, session::KernelGraph *graph);
|
||||||
virtual bool Run(session::KernelGraph *graph);
|
virtual bool Run(session::KernelGraph *graph);
|
||||||
virtual bool DumpData(session::KernelGraph *graph);
|
virtual bool DumpData(session::KernelGraph *graph);
|
||||||
virtual bool RunTask(const session::KernelGraph *graph);
|
virtual bool RunTask(const session::KernelGraph *graph);
|
||||||
|
|
|
@ -222,6 +222,7 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
|
||||||
auto optimizer = std::make_shared<GraphOptimizer>();
|
auto optimizer = std::make_shared<GraphOptimizer>();
|
||||||
auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
|
auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
|
||||||
ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
|
ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
|
||||||
|
ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
|
||||||
|
|
||||||
optimizer->AddPassManager(ir_fusion_pm);
|
optimizer->AddPassManager(ir_fusion_pm);
|
||||||
(void)optimizer->Optimize(kernel_graph);
|
(void)optimizer->Optimize(kernel_graph);
|
||||||
|
|
|
@ -25,7 +25,7 @@ assign_add_op_info = TBERegOp("AssignAdd") \
|
||||||
.partial_flag(True) \
|
.partial_flag(True) \
|
||||||
.input(0, "ref", False, "required", "all") \
|
.input(0, "ref", False, "required", "all") \
|
||||||
.input(1, "value", False, "required", "all") \
|
.input(1, "value", False, "required", "all") \
|
||||||
.output(0, "output_ref", False, "required", "all") \
|
.output(0, "ref", False, "required", "all") \
|
||||||
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
|
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||||
.dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \
|
.dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \
|
||||||
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
|
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||||
|
|
|
@ -210,6 +210,10 @@ class Print(PrimitiveWithInfer):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def __call__(self, *args):
|
||||||
|
for arg in args:
|
||||||
|
print(arg)
|
||||||
|
|
||||||
def infer_shape(self, *inputs):
|
def infer_shape(self, *inputs):
|
||||||
return [1]
|
return [1]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue