Support non-tail recursive graphs
commit ba65fb9f3c (parent d9f6a6277d), forked from mindspore-Ecosystem/mindspore
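For context, a non-tail recursive call is one whose result the caller still needs after the call returns, so the caller's live values must be saved across the call and restored afterwards. A minimal standalone C++ illustration of the difference (not part of this change):

#include <cstdint>

// Tail recursion: nothing is left to do after the recursive call,
// so no caller state needs to survive it.
int64_t SumTail(int64_t n, int64_t acc) { return n == 0 ? acc : SumTail(n - 1, acc + n); }

// Non-tail recursion: the caller still needs `n` after the recursive call
// returns, so `n` must be preserved across the call. The diff below adds
// StackInit/StackPush/StackPop/StackDestroy ops to do the same for live
// values of recursive kernel graphs on Ascend.
int64_t SumNonTail(int64_t n) { return n == 0 ? 0 : n + SumNonTail(n - 1); }

int main() { return SumTail(10, 0) == SumNonTail(10) ? 0 : 1; }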
@@ -1,5 +1,5 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -92,7 +92,9 @@ bool SetIOSize(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<A
      return false;
    }
    kernel_mod_ptr->SetInputSizeList(input_size_list);

  if (output_num == 1 && HasAbstractMonad(anf_node)) {
    output_num = 0;
  }
  for (size_t i = 0; i < output_num; i++) {
    std::vector<size_t> shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i);
    TypePtr type_ptr = TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, i));

@@ -229,6 +231,9 @@ void SetNodeOutputs(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef
  MS_EXCEPTION_IF_NULL(proto);
  MS_EXCEPTION_IF_NULL(anf_node);
  size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node);
  if (output_num == 1 && HasAbstractMonad(anf_node)) {
    output_num = 0;
  }
  if (output_num == 0) {
    MS_LOG(INFO) << "Node [" << AnfAlgo::GetCNodeName(anf_node) << "] does not have output. ";
    return;
@@ -1,5 +1,5 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -38,32 +38,10 @@ void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<
    return;
  }
  // For compatibility with the current framework
  if (op_name == kPrint || op_name == kGetNext || op_name == kPack || op_name == kMeshgrid) {
    std::vector<std::string> inputs_format{};
    std::vector<TypeId> inputs_type{};
    if (op_name == kPrint || op_name == kPack || op_name == kMeshgrid) {
      size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
      for (size_t input_index = 0; input_index < input_num; ++input_index) {
        inputs_format.emplace_back(kOpFormat_DEFAULT);
        inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));
      }
    }
    std::vector<std::string> outputs_format;
    std::vector<TypeId> outputs_type;
    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
    for (size_t output_index = 0; output_index < output_num; ++output_index) {
      outputs_format.emplace_back(kOpFormat_DEFAULT);
      outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
    }
    auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
    builder.SetInputsFormat(inputs_format);
    builder.SetInputsDeviceType(inputs_type);
    builder.SetOutputsFormat(outputs_format);
    builder.SetOutputsDeviceType(outputs_type);
    builder.SetProcessor(AICPU);
    builder.SetKernelType(AICPU_KERNEL);
    builder.SetFusionType(OPAQUE);
    kernel_info_list->push_back(builder.Build());
  if (op_name == kPrint || op_name == kGetNext || op_name == kPack || op_name == kMeshgrid ||
      op_name == kStackInitOpName || op_name == kStackDestroyOpName || op_name == kStackPushOpName ||
      op_name == kStackPopOpName) {
    AicpuMetadataInfoForSpecialNodes(kernel_node, kernel_info_list);
    return;
  }
  if (!ParseMetadata(kernel_node, op_info_ptr, AICPU, kernel_info_list)) {

@@ -71,5 +49,37 @@ void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<
    return;
  }
}

void AicpuMetadataInfoForSpecialNodes(const CNodePtr &kernel_node,
                                      std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
  std::vector<std::string> inputs_format{};
  std::vector<TypeId> inputs_type{};
  auto op_name = AnfAlgo::GetCNodeName(kernel_node);
  if (op_name == kPrint || op_name == kPack || op_name == kMeshgrid || op_name == kStackInitOpName ||
      op_name == kStackDestroyOpName || op_name == kStackPushOpName || op_name == kStackPopOpName) {
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    for (size_t input_index = 0; input_index < input_num; ++input_index) {
      inputs_format.emplace_back(kOpFormat_DEFAULT);
      inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));
    }
  }
  std::vector<std::string> outputs_format;
  std::vector<TypeId> outputs_type;
  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  for (size_t output_index = 0; output_index < output_num; ++output_index) {
    outputs_format.emplace_back(kOpFormat_DEFAULT);
    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
  }
  auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
  builder.SetInputsFormat(inputs_format);
  builder.SetInputsDeviceType(inputs_type);
  builder.SetOutputsFormat(outputs_format);
  builder.SetOutputsDeviceType(outputs_type);
  builder.SetProcessor(AICPU);
  builder.SetKernelType(AICPU_KERNEL);
  builder.SetFusionType(OPAQUE);
  kernel_info_list->push_back(builder.Build());
  return;
}
}  // namespace kernel
}  // namespace mindspore
@@ -1,5 +1,5 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -25,6 +25,8 @@
namespace mindspore {
namespace kernel {
void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
void AicpuMetadataInfoForSpecialNodes(const CNodePtr &kernel_node,
                                      std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AICPU_AICPU_KERNEL_META_DATA_H_
@@ -154,7 +154,7 @@ AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, co
    return nullptr;
  }
  auto next_op_name = AnfAlgo::GetCNodeName(next_cnode);
  if (next_op_name == prim::kPrimSend->name()) {
  if (next_op_name == prim::kPrimSend->name() || next_op_name == kStackPushOpName) {
    return nullptr;
  }
  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;

@@ -229,7 +229,8 @@ AnfNodePtr MergeCastToPriorOp(const FuncGraphPtr &graph, const CNodePtr &cur_nod
  }

  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
  if (AnfAlgo::GetCNodeName(prior_op) == prim::kPrimReceive->name()) {
  if (AnfAlgo::GetCNodeName(prior_op) == prim::kPrimReceive->name() ||
      AnfAlgo::GetCNodeName(prior_op) == kStackPopOpName) {
    return nullptr;
  }
  kernel_query->Query(prior_op, &kernel_info_list);
@@ -106,6 +106,43 @@ std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) {
enum ShapeType { kMaxShape, kMinShape };
}  // namespace

AnfNodePtr AnfRuntimeAlgorithm::MakeMonadValueNode(const KernelGraphPtr &kg) {
  return kg->NewValueNode(kUMonad->ToAbstract(), kUMonad);
}

// Convert:
// a = former(xxx)
// b = latter(x, xxx)
// To:
// a = former(xxx)
// d1 = Depend(x, a)
// b = latter(d1, xxx)
// ...
// out = Depend(out, latter)
void AnfRuntimeAlgorithm::KeepOrder(const KernelGraphPtr &kg, const AnfNodePtr &former, const AnfNodePtr &latter) {
  if (latter->isa<CNode>()) {
    auto latter_cnode = latter->cast<CNodePtr>();
    constexpr size_t inputsize = 2;
    constexpr size_t kFirstDataInputIndex = 1;
    if (latter_cnode->inputs().size() < inputsize) {
      return;
    }
    auto latter_input = latter_cnode->input(kFirstDataInputIndex);
    auto depend1 = kg->NewCNode({NewValueNode(prim::kPrimDepend), latter_input, former});
    depend1->set_abstract(latter_input->abstract());
    latter_cnode->set_input(kFirstDataInputIndex, depend1);

    auto return_node = kg->get_return();
    MS_EXCEPTION_IF_NULL(return_node);
    auto depend2 = kg->NewCNode(
      {NewValueNode(prim::kPrimDepend), return_node->cast<CNodePtr>()->input(kFirstDataInputIndex), latter});
    depend2->set_abstract(return_node->cast<CNodePtr>()->input(kFirstDataInputIndex)->abstract());
    kg->set_output(depend2);
    MS_LOG(DEBUG) << "former: " << former->DebugString() << ", latter: " << latter->DebugString()
                  << ", depend1: " << depend1->DebugString() << ", depend2: " << depend2->DebugString();
  }
}

AnfNodePtr AnfRuntimeAlgorithm::GetTupleGetItemRealInput(const CNodePtr &tuple_get_item) {
  MS_EXCEPTION_IF_NULL(tuple_get_item);
  if (tuple_get_item->size() != kTupleGetItemInputSize) {

@@ -1529,6 +1566,13 @@ bool AnfRuntimeAlgorithm::IsIndependentNode(const CNodePtr &node) {
    return false;
  }

  // aicpu stack ops are not independent nodes.
  if (AnfAlgo::GetCNodeName(node) == kStackInitOpName || AnfAlgo::GetCNodeName(node) == kStackDestroyOpName ||
      AnfAlgo::GetCNodeName(node) == kStackPopOpName || AnfAlgo::GetCNodeName(node) == kStackPushOpName) {
    MS_LOG(INFO) << "AICPU stack ops should not be independent node";
    return false;
  }

  size_t input_nums = AnfAlgo::GetInputTensorNum(node);
  if (input_nums == 0) {
    return true;
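The KeepOrder helper above encodes execution order with Depend nodes: Depend(value, dependency) forwards `value` unchanged but adds an edge to `dependency`, so any topological ordering of the graph runs `former` before `latter`. A standalone sketch of that idea, using a toy node type with illustrative names rather than the MindSpore IR classes:

#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy graph node: `inputs` are edges that must execute before this node.
struct Node {
  std::string name;
  std::vector<std::shared_ptr<Node>> inputs;
};
using NodePtr = std::shared_ptr<Node>;

NodePtr MakeNode(std::string name, std::vector<NodePtr> inputs = {}) {
  return std::make_shared<Node>(Node{std::move(name), std::move(inputs)});
}

// Depend forwards its first input's value; the second input only adds ordering.
NodePtr Depend(const NodePtr &value, const NodePtr &dependency) {
  return MakeNode("Depend(" + value->name + ", " + dependency->name + ")", {value, dependency});
}

void Print(const NodePtr &n, int depth = 0) {
  std::cout << std::string(depth * 2, ' ') << n->name << "\n";
  for (const auto &i : n->inputs) Print(i, depth + 1);
}

int main() {
  auto x = MakeNode("x");
  auto former = MakeNode("former");
  // `latter` now consumes Depend(x, former), so it is ordered after `former`
  // even though it does not use former's value.
  auto latter = MakeNode("latter", {Depend(x, former)});
  Print(latter);
  return 0;
}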
@@ -43,6 +43,8 @@ using DeviceAddress = device::DeviceAddress;
using DeviceAddressPtr = device::DeviceAddressPtr;
class AnfRuntimeAlgorithm {
 public:
  static AnfNodePtr MakeMonadValueNode(const KernelGraphPtr &kg);
  static void KeepOrder(const KernelGraphPtr &kg, const AnfNodePtr &former, const AnfNodePtr &latter);
  // get real input node of tuple_get_item
  static AnfNodePtr GetTupleGetItemRealInput(const CNodePtr &tuple_get_item);
  static size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item);
@@ -145,6 +145,9 @@ struct CallSite {
  // Call/Switch/SwitchLayer
  CNodePtr cnode;

  // CNode after transferring to LabelGoto/LabelSwitch/LabelSet.
  CNodePtr conversion_cnode;

  // The last monad before call.
  AnfNodePtr last_monad = nullptr;
@@ -286,6 +289,12 @@ class AscendAutoMonadContext : public BaseContext {

  const KernelGraphPtr &TopGraph() const { return top_graph_; }

  // Whether a stack has already been created.
  const bool HasInitedStack() const { return inited_stack_; }

  // Set the flag that indicates whether a stack has already been created.
  void SetInitedStack(bool flag) { inited_stack_ = flag; }

  // Map kernel_graph to its call info.
  OrderedMap<KernelGraphPtr, CallInfo> call_info_map;
@@ -298,6 +307,9 @@ class AscendAutoMonadContext : public BaseContext {

  // Current label id.
  uint32_t label_id_ = 0;

  // Create a stack for multi-call and non-tail recursion.
  bool inited_stack_ = false;
};

//
@@ -605,16 +617,22 @@ class AscendAutoMonadConverter {

 private:
  AscendAutoMonadConverter(const KernelGraphPtr &kg, AscendAutoMonadContext *context, CallInfo *call_info)
      : kernel_graph_(kg), context_(*context), call_info_(*call_info) {}
      : kernel_graph_(kg),
        context_(*context),
        call_info_(*call_info),
        name_index_(0),
        need_stackops_(call_info->recursive) {}
  ~AscendAutoMonadConverter() = default;

  void Run() {
    // Create a stack.
    InitStack();
    // Setup entry label if found.
    SetupEntryLabel();

    // Handle call sites.
    for (auto &call_site : call_info_.call_sites) {
      HandleCallSite(call_site);
      HandleCallSite(&call_site);
    }
    // Handle return points.
    HandleReturnPoints();
@@ -622,20 +640,148 @@ class AscendAutoMonadConverter {
    if (monad_) {
      MakeMonadDepend();
    }
    // Handle recursive calls.
    kernel_graph_->SetExecOrderByDefault();
    for (auto &call_site : call_info_.call_sites) {
      if (need_stackops_ && call_site.recursive) {
        MS_LOG(INFO) << "graph:" << kernel_graph_->ToString() << ", loop call_site:" << call_site.cnode->DebugString();
        InsertStackOps(call_site);
      }
    }
  }

  void HandleCallSite(const CallSite &call_site) {
  // Create a Stack for StackOps if needed.
  void InitStack() {
    if (!context_.HasInitedStack() && need_stackops_) {
      auto top_graph = context_.TopGraph();
      auto exec_order = top_graph->execution_order();
      auto stack_init = StackInit(top_graph);
      AnfAlgo::KeepOrder(top_graph, stack_init, *exec_order.begin());
      auto stack_destroy = StackDestroy(top_graph);
      AnfAlgo::KeepOrder(top_graph, *exec_order.rbegin(), stack_destroy);
      top_graph->SetExecOrderByDefault();
      context_.SetInitedStack(true);
    }
  }

  // Insert StackOps for call_site in the recursive graph.
  void InsertStackOps(const CallSite &call_site) {
    auto call_point = call_site.conversion_cnode;
    auto exec_order = kernel_graph_->execution_order();
    std::vector<AnfNodePtr> before_nodes;
    std::vector<CNodePtr> stack_pushs;
    bool find_call_point = false;
    for (auto &node : exec_order) {
      auto node_name = AnfAlgo::GetCNodeName(node);
      if (node == call_point) {
        find_call_point = true;
        continue;
      }
      if (!find_call_point) {
        if (node_name == kLabelGotoOpName || node_name == kLabelSwitchOpName || node_name == kLabelSetOpName ||
            node_name == prim::kPrimAssign->name()) {
          MS_LOG(DEBUG) << "Ignore goto/switch/set/assign ops";
        } else {
          before_nodes.push_back(node);
          MS_LOG(DEBUG) << "push back node:" << node->DebugString();
        }
        continue;
      }
      if (node->size() == 0 || node_name == kLabelGotoOpName || node_name == kLabelSetOpName ||
          node_name == prim::kPrimAssign->name()) {
        continue;
      }
      FindInputNode(before_nodes, node, &stack_pushs);
    }
    InsertStackPush(kernel_graph_, call_point, stack_pushs);
  }

  // Find nodes which need StackOps, and insert StackOps for them.
  void FindInputNode(const std::vector<AnfNodePtr> &before_nodes, const CNodePtr &node,
                     std::vector<CNodePtr> *stack_pushs) {
    uint32_t start_index = 1;
    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimAssign)) {
      start_index = 2;
    }
    // auto node_inputs = node->inputs();
    for (uint32_t i = start_index; i < node->inputs().size(); i++) {
      auto node_input = node->input(i);
      // No need to save monad inputs.
      if (HasAbstractMonad(node_input)) {
        continue;
      }
      MS_LOG(DEBUG) << "check node input[" << i << "]: " << node_input->DebugString();
      if (node_input->isa<Parameter>()) {
        MS_LOG(DEBUG) << "node_input:" << node_input->DebugString() << " is a param";
        CNodePtr stack_pop = InsertStackPop(kernel_graph_, node_input, stack_pushs);
        node->set_input(i, stack_pop);
        KeepOrderForStackPop(kernel_graph_, stack_pop, node);
        continue;
      }
      auto iter = std::find_if(before_nodes.begin(), before_nodes.end(),
                               [node_input](auto before_node) { return before_node == node_input; });
      if (iter != before_nodes.end()) {
        CNodePtr stack_pop = InsertStackPop(kernel_graph_, *iter, stack_pushs);
        node->set_input(i, stack_pop);
        KeepOrderForStackPop(kernel_graph_, stack_pop, node);
      }
    }
  }

  // Create StackOps for node_input.
  CNodePtr InsertStackPop(const KernelGraphPtr &kg, const AnfNodePtr &node_input, std::vector<CNodePtr> *stack_pushs) {
    auto stack_push = StackPush(node_input);
    stack_pushs->emplace_back(stack_push);
    auto stack_pop = StackPop();
    stack_pop->set_abstract(node_input->abstract());
    return stack_pop;
  }

  // Arrange the StackPush nodes so that the value pushed last is the first one popped (LIFO),
  // while ensuring that the last StackPush node is next to the jump_node.
  void InsertStackPush(const KernelGraphPtr &kg, const CNodePtr &jump_node, const std::vector<CNodePtr> &stack_pushs) {
    MS_LOG(DEBUG) << "There are " << stack_pushs.size() << " stack_push ops";
    if (stack_pushs.size() < 1) {
      return;
    }
    for (uint32_t i = 1; i < stack_pushs.size(); i++) {
      AnfAlgo::KeepOrder(kg, stack_pushs[i], stack_pushs[i - 1]);
    }
    auto nodes = kg->execution_order();
    auto node_iter = std::find(nodes.begin(), nodes.end(), jump_node);
    AnfAlgo::KeepOrder(kg, stack_pushs[0], jump_node);
    if (node_iter != nodes.begin()) {
      AnfAlgo::KeepOrder(kg, *(node_iter - 1), *stack_pushs.rbegin());
    }
  }

  // Ensure StackPop is next to the jump_node.
  void KeepOrderForStackPop(const KernelGraphPtr &kg, const CNodePtr &pop, const CNodePtr &jump_node) {
    auto nodes = kg->execution_order();
    auto node_iter = std::find(nodes.cbegin(), nodes.cend(), jump_node);
    if (node_iter == nodes.cend()) {
      MS_LOG(EXCEPTION) << "Cannot find node: " << jump_node->DebugString();
    }
    // Insert between jump_node-1 and jump_node.
    if (node_iter != nodes.begin()) {
      CNodePtr node = *(node_iter - 1);
      AnfAlgo::KeepOrder(kg, node, pop);
    }
    AnfAlgo::KeepOrder(kg, pop, jump_node);
  }

  void HandleCallSite(CallSite *call_site) {
    // Update last_monad_.
    last_monad_ = call_site.last_monad;
    last_monad_ = call_site->last_monad;

    // The call/switch/switch_layer cnode.
    auto &cnode = call_site.cnode;
    auto &cnode = call_site->cnode;

    // Get branches of the call_site.
    // for call, there is one branch;
    // for switch, the first one is true branch;
    // for switch_layer, the first one is 0 branch.
    auto &branches = call_site.callees;
    auto &branches = call_site->callees;

    // Link arguments and find labels for branches.
    std::vector<KernelGraphPtr> graphes;
@@ -664,13 +810,14 @@ class AscendAutoMonadConverter {

    // Create LabelGoto or LabelSwitch node.
    auto label_goto_switch = MakeLabelGotoSwitch(cnode, graphes, labels);
    call_site->conversion_cnode = label_goto_switch;

    // Setup return label and output if required.
    if (call_site.return_label != kNoLabel) {
      auto label_node = LabelSet(call_site.return_label);
      AnfNodePtr output = call_site.out_param;
    if (call_site->return_label != kNoLabel) {
      auto label_node = LabelSet(call_site->return_label);
      AnfNodePtr output = call_site->out_param;
      MS_EXCEPTION_IF_NULL(output);
      const bool is_single_call = call_site.label_indexes.empty();
      const bool is_single_call = call_site->label_indexes.empty();
      if (is_single_call) {
        // For single call, let output depend on the label node,
        // this ensures the return label is set before output is used.
@@ -688,7 +835,7 @@ class AscendAutoMonadConverter {
    }

    // If no return label required, it should be a tail call.
    if (!call_site.tail) {
    if (!call_site->tail) {
      MS_LOG(EXCEPTION) << "Return label not set for non-tail call " << cnode->DebugString();
    }
    // For tail calls, replace origin call node with label_goto/label_switch.
@@ -697,8 +844,8 @@ class AscendAutoMonadConverter {
  }

  // Assign label indexes to label parameters for a call site.
  void AssignLabelIndexes(const CallSite &call_site) {
    for (auto &[label_param, label_index] : call_site.label_indexes) {
  void AssignLabelIndexes(const CallSite *call_site) {
    for (auto &[label_param, label_index] : call_site->label_indexes) {
      auto index_value = GetIndexValueNode(label_index);
      auto assign = Assign(label_param, index_value, false, false, false);
      monad_ = UpdateState(GetMonad(), assign);
@@ -1020,6 +1167,50 @@ class AscendAutoMonadConverter {
    AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue(graphs), node);
  }

  // Make a StackInit node.
  CNodePtr StackInit(const KernelGraphPtr &kg) {
    auto monad = AnfAlgo::MakeMonadValueNode(kg);
    auto stack_init = NewPrimitive(prim::kPrimStackInit);
    auto cnode = kg->NewCNode({stack_init, monad});
    AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue<int64_t>(0), cnode);
    cnode->set_abstract(monad->abstract());
    return cnode;
  }

  // Make a StackDestroy node.
  CNodePtr StackDestroy(const KernelGraphPtr &kg) {
    auto monad = AnfAlgo::MakeMonadValueNode(kg);
    auto stack_destroy = NewPrimitive(prim::kPrimStackDestroy);
    auto cnode = kg->NewCNode({stack_destroy, monad});
    AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue<int64_t>(0), cnode);
    cnode->set_abstract(monad->abstract());
    return cnode;
  }

  // Make a StackPush node.
  CNodePtr StackPush(const AnfNodePtr &input) {
    auto monad = AnfAlgo::MakeMonadValueNode(kernel_graph_);
    auto stack_push = NewPrimitive(prim::kPrimStackPush);
    auto cnode = kernel_graph_->NewCNode({stack_push, input, monad});
    AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue<int64_t>(0), cnode);
    auto op_name = std::to_string(kernel_graph_->graph_id()) + "_stack_push_" + std::to_string(name_index_++);
    AnfAlgo::SetNodeAttr(kAttrStackOpName, MakeValue(op_name), cnode);
    cnode->set_abstract(monad->abstract());
    return cnode;
  }

  // Make a StackPop node.
  CNodePtr StackPop() {
    auto monad = AnfAlgo::MakeMonadValueNode(kernel_graph_);
    auto stack_pop = NewPrimitive(prim::kPrimStackPop);
    auto cnode = kernel_graph_->NewCNode({stack_pop, monad});
    AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue<int64_t>(0), cnode);
    auto op_name = std::to_string(kernel_graph_->graph_id()) + "_stack_pop_" + std::to_string(name_index_++);
    AnfAlgo::SetNodeAttr(kAttrStackOpName, MakeValue(op_name), cnode);
    cnode->set_abstract(monad->abstract());  // need to refresh output's abstract().
    return cnode;
  }

 private:
  const KernelGraphPtr &kernel_graph_;
  AscendAutoMonadContext &context_;
@@ -1038,6 +1229,12 @@ class AscendAutoMonadConverter {

  // Index value node cache for reuse.
  std::map<uint32_t, ValueNodePtr> index_nodes_;

  // Running index used to name stack ops.
  uint32_t name_index_;

  // Whether stack ops need to be inserted.
  bool need_stackops_;
};

constexpr size_t kAssignTargetIndex = 1;
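The InsertStackOps/InsertStackPush/KeepOrderForStackPop helpers above push each live value onto the stack right before the recursive jump and pop it back, last pushed first, where it is consumed after the call returns. A standalone sketch of that save/restore discipline in plain C++ (illustrative only; the real StackPush/StackPop ops run as AICPU kernels on device):

#include <cassert>
#include <stack>

// Stands in for the device-side stack created by StackInit/StackDestroy.
std::stack<int> value_stack;

int Factorial(int n) {
  if (n == 0) {
    return 1;
  }
  value_stack.push(n);               // StackPush: save the live value before the recursive call
  int sub = Factorial(n - 1);        // the jump back into the recursive graph
  int saved_n = value_stack.top();   // StackPop: restore it right after the call returns
  value_stack.pop();
  return saved_n * sub;              // the saved value is still needed here, hence non-tail
}

int main() {
  assert(Factorial(5) == 120);
  return 0;
}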
@@ -116,6 +116,10 @@ constexpr auto kApplyProximalAdagradOpName = "ApplyProximalAdagrad ";
constexpr auto kApplyProximalGradientDescentOpName = "ApplyProximalGradientDescent";
constexpr auto kApplyRMSPropOpName = "ApplyRMSProp";
constexpr auto kTransDataOpName = "TransData";
constexpr auto kStackInitOpName = "StackInit";
constexpr auto kStackPushOpName = "StackPush";
constexpr auto kStackPopOpName = "StackPop";
constexpr auto kStackDestroyOpName = "StackDestroy";
constexpr auto kBNTrainingUpdateGradOpName = "BNTrainingUpdateGrad";
constexpr auto kBNTrainingReduceGradOpName = "BNTrainingReduceGrad";
constexpr auto kSquareSumV1OpName = "SquareSumV1";

@@ -380,6 +384,7 @@ constexpr auto kAttrRankSize = "rank_size";
constexpr auto kAttrPadDimSize = "pad_dim_size";
constexpr auto kAttrPaddings = "paddings";
constexpr auto kAttrNumSegments = "num_segments";
constexpr auto kAttrStackOpName = "stack_op_name";
constexpr auto kAttrBegin = "begin";
constexpr auto kAttrSize = "size";
constexpr auto kAttrIsDynamicShape = "is_dynamic_shape";
@@ -105,6 +105,12 @@ inline const PrimitivePtr kPrimLabelGoto = std::make_shared<Primitive>("LabelGot
inline const PrimitivePtr kPrimLabelSwitch = std::make_shared<Primitive>("LabelSwitch");
inline const PrimitivePtr kPrimLabelSet = std::make_shared<Primitive>("LabelSet");

// Stack ops
inline const PrimitivePtr kPrimStackInit = std::make_shared<Primitive>("StackInit");
inline const PrimitivePtr kPrimStackDestroy = std::make_shared<Primitive>("StackDestroy");
inline const PrimitivePtr kPrimStackPush = std::make_shared<Primitive>("StackPush");
inline const PrimitivePtr kPrimStackPop = std::make_shared<Primitive>("StackPop");

// Arrays
inline const PrimitivePtr kPrimBroadcastTo = std::make_shared<Primitive>("BroadcastTo");
inline const PrimitivePtr kPrimScalarToArray = std::make_shared<Primitive>("scalar_to_array");
@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -33,32 +33,6 @@ grad_by_list = C.GradOperation(get_by_list=True)
grad_all = C.GradOperation(get_all=True)


def test_while_forward():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()

        def construct(self, idx, end, x):
            while idx < end:
                part = x[idx, :, :]
                max_num = self.max(part)
                x[idx, :, 0:2] = max_num
                idx = idx + 1
            return x
    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    net = MyWhileNet()
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


def test_while_grad():
    class MyWhileNet(nn.Cell):
        def __init__(self):

@@ -95,6 +69,68 @@ def test_while_grad():
    assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(graph_output[2].asnumpy(), pynative_output[2].asnumpy(), 0.0001, 0.0001)

def test_while_with_const_param_grad():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.mul = P.Mul()
            self.add = P.Add()

        def construct(self, x, y):
            while x < y:
                z = self.mul(x, x)
                x = self.add(z, 1)
            return x

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor([1.1], dtype=ms.float32)
    end = Tensor([8.0], dtype=ms.float32)
    graph_output = net(idx, end)
    expect_one = np.array([1.14433983e+02], dtype=np.float32)
    expect_two = np.array([0], dtype=np.float32)
    assert np.allclose(graph_output[0].asnumpy(), expect_one, 0.0001, 0.0001)
    assert np.allclose(graph_output[1].asnumpy(), expect_two, 0.0001, 0.0001)

def test_while_with_variable_grad():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.mul = P.Mul()
            self.add = P.Add()

        def construct(self, x, y):
            while x < y:
                z = self.mul(x, x)
                x = self.add(z, y)
            return x

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor([1.1], dtype=ms.float32)
    end = Tensor([8.0], dtype=ms.float32)
    graph_output = net(idx, end)
    expect_one = np.array([2.20000005e+00], dtype=np.float32)
    expect_two = np.array([1.00000000e+00], dtype=np.float32)
    assert np.allclose(graph_output[0].asnumpy(), expect_one, 0.0001, 0.0001)
    assert np.allclose(graph_output[1].asnumpy(), expect_two, 0.0001, 0.0001)

def test_while_with_param_forward():
    class MyWhileNet(nn.Cell):
        def __init__(self):

@@ -153,7 +189,6 @@ def test_while_endless_case():
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


def test_while_with_param_grad():
    class MyWhileNet(nn.Cell):
        def __init__(self):

@@ -180,7 +215,6 @@ def test_while_with_param_grad():

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)

@@ -188,10 +222,8 @@ def test_while_with_param_grad():
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
    expect = np.array([[[2, 2], [2, 2]], [[2, 2], [2, 2]]], dtype=np.int32)
    assert np.allclose(graph_output[0].asnumpy(), expect, 0.0001, 0.0001)

def test_while_with_param_forward_with_const_branch():
    class MyWhileNet(nn.Cell):