!499 pynative support topk and print op

Merge pull request !499 from JoyLvliang/pynative-support-topk-and-print
mindspore-ci-bot 2020-04-21 15:58:51 +08:00 committed by Gitee
commit 475f62f680
5 changed files with 9 additions and 3 deletions


@@ -135,10 +135,11 @@ void KernelRuntime::AssignMemory(session::KernelGraph *graph) {
 }
 
 void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors,
-                                      const session::KernelGraph *graph) {
+                                      session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(graph);
   // assign memory for input nodes
   RunOpAssignInputMemory(input_tensors, graph);
+  AssignStaticMemoryValueNode(graph);
   for (const auto &cnode : graph->execution_order()) {
     // assign memory for output nodes
     RunOpAssignOutputMemory(cnode);


@@ -46,7 +46,7 @@ class KernelRuntime {
   virtual ~KernelRuntime();
   virtual bool Init() = 0;
   virtual void AssignMemory(session::KernelGraph *graph);
-  void RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, const session::KernelGraph *graph);
+  void RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, session::KernelGraph *graph);
   virtual bool Run(session::KernelGraph *graph);
   virtual bool DumpData(session::KernelGraph *graph);
   virtual bool RunTask(const session::KernelGraph *graph);


@@ -222,6 +222,7 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
   auto optimizer = std::make_shared<GraphOptimizer>();
   auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
   ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
+  ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
   optimizer->AddPassManager(ir_fusion_pm);
   (void)optimizer->Optimize(kernel_graph);
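
With TopKSplit added to the single-op IR fusion pass manager, TopK can be dispatched op by op in PyNative mode. A minimal usage sketch, not part of the patch, assuming the operations API of this MindSpore version on an Ascend target:

import numpy as np
from mindspore import Tensor, context
from mindspore.ops import operations as P

# PyNative mode executes each primitive eagerly, one op at a time.
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")

topk = P.TopK(sorted=True)
x = Tensor(np.array([1.0, 3.0, 2.0, 5.0, 4.0], np.float32))
values, indices = topk(x, 3)  # top-3 values and their indices
print(values, indices)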


@@ -25,7 +25,7 @@ assign_add_op_info = TBERegOp("AssignAdd") \
     .partial_flag(True) \
     .input(0, "ref", False, "required", "all") \
     .input(1, "value", False, "required", "all") \
-    .output(0, "output_ref", False, "required", "all") \
+    .output(0, "ref", False, "required", "all") \
     .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
     .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \
     .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
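
The hunk above only renames the registered TBE output from output_ref to ref; Python-side usage of AssignAdd is unchanged. A minimal sketch of applying the op to a Parameter, illustrative only (the cell name AccumulateNet is hypothetical):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")

class AccumulateNet(nn.Cell):
    def __init__(self):
        super(AccumulateNet, self).__init__()
        self.assign_add = P.AssignAdd()
        self.acc = Parameter(Tensor(np.zeros([1], np.int32)), name="acc")

    def construct(self, value):
        self.assign_add(self.acc, value)  # in-place: acc += value
        return self.acc

net = AccumulateNet()
print(net(Tensor(np.array([5], np.int32))))  # acc is now [5]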


@@ -210,6 +210,10 @@ class Print(PrimitiveWithInfer):
 
     def __init__(self):
         pass
 
+    def __call__(self, *args):
+        for arg in args:
+            print(arg)
+
     def infer_shape(self, *inputs):
         return [1]
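
With the __call__ added above, Print works directly in PyNative mode: each positional argument is printed eagerly instead of requiring graph compilation. A minimal sketch, illustrative only:

import numpy as np
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)

print_op = P.Print()  # __init__ takes no attributes
print_op("x is:", Tensor(np.array([[1, 2], [3, 4]], np.float32)))  # each argument is printed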