forked from OSSInnovation/mindspore
!499 pynative support topk and print op
Merge pull request !499 from JoyLvliang/pynative-support-topk-and-print
This commit is contained in:
commit
475f62f680
|
@ -135,10 +135,11 @@ void KernelRuntime::AssignMemory(session::KernelGraph *graph) {
|
|||
}
|
||||
|
||||
void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors,
|
||||
const session::KernelGraph *graph) {
|
||||
session::KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
// assign memory for input nodes
|
||||
RunOpAssignInputMemory(input_tensors, graph);
|
||||
AssignStaticMemoryValueNode(graph);
|
||||
for (const auto &cnode : graph->execution_order()) {
|
||||
// assign memory for output nodes
|
||||
RunOpAssignOutputMemory(cnode);
|
||||
|
|
|
@ -46,7 +46,7 @@ class KernelRuntime {
|
|||
virtual ~KernelRuntime();
|
||||
virtual bool Init() = 0;
|
||||
virtual void AssignMemory(session::KernelGraph *graph);
|
||||
void RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, const session::KernelGraph *graph);
|
||||
void RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, session::KernelGraph *graph);
|
||||
virtual bool Run(session::KernelGraph *graph);
|
||||
virtual bool DumpData(session::KernelGraph *graph);
|
||||
virtual bool RunTask(const session::KernelGraph *graph);
|
||||
|
|
|
@ -222,6 +222,7 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
|
|||
auto optimizer = std::make_shared<GraphOptimizer>();
|
||||
auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
|
||||
ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
|
||||
|
||||
optimizer->AddPassManager(ir_fusion_pm);
|
||||
(void)optimizer->Optimize(kernel_graph);
|
||||
|
|
|
@ -25,7 +25,7 @@ assign_add_op_info = TBERegOp("AssignAdd") \
|
|||
.partial_flag(True) \
|
||||
.input(0, "ref", False, "required", "all") \
|
||||
.input(1, "value", False, "required", "all") \
|
||||
.output(0, "output_ref", False, "required", "all") \
|
||||
.output(0, "ref", False, "required", "all") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||
|
|
|
@ -210,6 +210,10 @@ class Print(PrimitiveWithInfer):
|
|||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, *args):
|
||||
for arg in args:
|
||||
print(arg)
|
||||
|
||||
def infer_shape(self, *inputs):
|
||||
return [1]
|
||||
|
||||
|
|
Loading…
Reference in New Issue