Optimize execution order sorting

This commit is contained in:
kswang 2020-04-15 11:12:14 +08:00
parent 7214c04114
commit 83eeac9310
7 changed files with 128 additions and 78 deletions

View File

@ -23,7 +23,7 @@ namespace ascend {
class AscendMemoryManager : public MemoryManager {
public:
AscendMemoryManager() = default;
virtual ~AscendMemoryManager() = default;
~AscendMemoryManager() override = default;
void MallocDeviceMemory() override;
void FreeDeviceMemory() override;

View File

@ -26,6 +26,8 @@ namespace ascend {
class AscendMemoryPool : public DynamicMemPoolBestFit {
public:
~AscendMemoryPool() override = default;
AscendMemoryPool(const AscendMemoryPool&) = delete;
AscendMemoryPool& operator=(const AscendMemoryPool&) = delete;
size_t AllocDeviceMem(size_t size, DeviceMemPtr* addr) override;
bool FreeDeviceMem(const DeviceMemPtr& addr) override;
@ -51,13 +53,11 @@ class AscendMemoryPool : public DynamicMemPoolBestFit {
private:
AscendMemoryPool() = default;
AscendMemoryPool(const AscendMemoryPool&) = delete;
AscendMemoryPool& operator=(const AscendMemoryPool&) = delete;
bool has_malloc_{false};
uint8_t* device_mem_pool_base_{nullptr};
uint64_t device_mem_pool_size_{0};
size_t free_mem_size_;
size_t total_mem_size_;
size_t free_mem_size_{0};
size_t total_mem_size_{0};
};
} // namespace ascend
} // namespace device

View File

@ -858,6 +858,14 @@ bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) {
return false;
}
// Returns true iff `node` is a CNode whose primitive name is kAllReduceOpName.
// Throws (via MS_EXCEPTION_IF_NULL) when `node` is null.
bool AnfRuntimeAlgorithm::IsAllReduceOp(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  return node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kAllReduceOpName;
}
bool AnfRuntimeAlgorithm::IsGetNext(const NotNull<AnfNodePtr> &node) {
auto kernel_name = AnfAlgo::GetCNodeName(node);
return kernel_name == kGetNextOpName;

View File

@ -176,6 +176,7 @@ class AnfRuntimeAlgorithm {
// get real input index for some tbe ops which input order is different between me and tbe impl
static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
static bool IsCommunicationOp(const AnfNodePtr &node);
static bool IsAllReduceOp(const AnfNodePtr &node);
static bool IsGetNext(const NotNull<AnfNodePtr> &node);
};
} // namespace session

View File

@ -50,90 +50,127 @@ std::vector<AnfNodePtr> KernelGraph::outputs() const {
}
void KernelGraph::SetExecOrderByDefault() {
BfsToUpdateNodeOutput();
std::stack<AnfNodePtr> seed_nodes;
UpdateNodeEdgeList(&seed_nodes);
execution_order_.clear();
std::queue<AnfNodePtr> allreduce_nodes;
std::queue<AnfNodePtr> zero_output_nodes;
std::unordered_set<AnfNodePtr> visited_nodes;
auto clear_output = [&zero_output_nodes, &allreduce_nodes, &visited_nodes, this](const AnfNodePtr &input) -> void {
if (node_output_num_[input] == 0 && visited_nodes.find(input) == visited_nodes.end()) {
MS_EXCEPTION_IF_NULL(input);
MS_LOG(DEBUG) << "Clear output num:" << input->DebugString();
(void)visited_nodes.insert(input);
if (input->isa<CNode>() && AnfAlgo::GetCNodeName(input) == kAllReduceOpName) {
allreduce_nodes.push(input);
} else {
zero_output_nodes.push(input);
}
}
};
zero_output_nodes.emplace(get_return());
while (!zero_output_nodes.empty() || !allreduce_nodes.empty()) {
AnfNodePtr node;
if (!zero_output_nodes.empty()) {
node = zero_output_nodes.front();
zero_output_nodes.pop();
} else {
node = allreduce_nodes.front();
allreduce_nodes.pop();
}
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
execution_order_.push_back(node->cast<CNodePtr>());
}
auto it = node_input_edges_.find(node);
if (it == node_input_edges_.end()) {
std::queue<AnfNodePtr> zero_input_nodes;
auto visit_node_descendant = [&visited_nodes, this](const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue) {
auto it = node_output_edges_.find(node);
if (it == node_output_edges_.end()) {
// value node and parameter has no input,no need to print log
if (node->isa<CNode>()) {
MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]";
}
continue;
return;
}
for (const auto &input_edge : it->second) {
if (node_output_num_.find(input_edge.first) == node_output_num_.end()) {
MS_EXCEPTION_IF_NULL(input_edge.first);
MS_LOG(EXCEPTION) << "Can't find node[" << input_edge.first->DebugString() << "]";
// visit all reduce node first, then other nodes
std::vector<AnfNodePtr> active_nodes;
for (const auto &output_edge : it->second) {
auto next_node = output_edge.first;
if (node_input_num_.find(next_node) == node_input_num_.end()) {
MS_EXCEPTION_IF_NULL(next_node);
MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]";
}
MS_EXCEPTION_IF_NULL(input_edge.first);
MS_LOG(DEBUG) << "Decrease input:" << input_edge.first->DebugString() << ",node:" << node->DebugString()
<< ",num: " << node_output_num_[input_edge.first] << ",decrease num:" << input_edge.second;
if (node_output_num_[input_edge.first] < input_edge.second) {
MS_LOG(EXCEPTION) << "Input node:" << input_edge.first->DebugString() << ",node_output_num"
<< node_output_num_[input_edge.first] << "depend edge:" << input_edge.second;
MS_EXCEPTION_IF_NULL(next_node);
MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString()
<< ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second;
if (node_input_num_[next_node] < output_edge.second) {
MS_LOG(EXCEPTION) << "Input node:" << next_node->DebugString() << ",node_output_num"
<< node_input_num_[next_node] << ",depend edge:" << output_edge.second;
}
node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second;
// allreduce first
if (node_input_num_[next_node] == 0 && visited_nodes.find(next_node) == visited_nodes.end()) {
(void)visited_nodes.insert(next_node);
if (AnfAlgo::IsAllReduceOp(next_node)) {
MS_LOG(DEBUG) << "visit node:" << next_node->DebugString();
visit_queue->push(next_node);
} else {
active_nodes.emplace_back(next_node);
}
}
}
for (auto &node : active_nodes) {
MS_LOG(DEBUG) << "visit node:" << node->DebugString();
visit_queue->push(node);
}
};
AnfNodePtr last_allreduce_node = nullptr;
std::queue<AnfNodePtr> allreduce_descendants;
while (!seed_nodes.empty() || last_allreduce_node != nullptr) {
// seed nodes first, then visit last all reduce node descendant
if (seed_nodes.empty()) {
visit_node_descendant(last_allreduce_node, &allreduce_descendants);
last_allreduce_node = nullptr;
} else {
zero_input_nodes.push(seed_nodes.top());
seed_nodes.pop();
}
// all reduce node descendant first, then common queue
while (!zero_input_nodes.empty() || !allreduce_descendants.empty()) {
AnfNodePtr node = nullptr;
bool is_allreduce_descendant = false;
if (allreduce_descendants.empty()) {
node = zero_input_nodes.front();
zero_input_nodes.pop();
} else {
node = allreduce_descendants.front();
allreduce_descendants.pop();
is_allreduce_descendant = true;
}
// add execute node
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
execution_order_.push_back(node->cast<CNodePtr>());
}
// for all reduce node, visit last all reduce node descendant
if (AnfAlgo::IsAllReduceOp(node)) {
if (last_allreduce_node != nullptr) {
visit_node_descendant(last_allreduce_node, &allreduce_descendants);
}
last_allreduce_node = node;
} else if (is_allreduce_descendant) {
visit_node_descendant(node, &allreduce_descendants);
} else {
visit_node_descendant(node, &zero_input_nodes);
}
node_output_num_[input_edge.first] = node_output_num_[input_edge.first] - input_edge.second;
clear_output(input_edge.first);
}
}
CheckLoop();
std::reverse(execution_order_.begin(), execution_order_.end());
}
void KernelGraph::CheckLoop() {
std::map<AnfNodePtr, size_t> none_zero_output;
if (node_output_edges_.size() != node_output_num_.size()) {
MS_LOG(EXCEPTION) << "node_output_edges_ size :" << node_output_edges_.size()
<< "not equal to node_output_num_ size:" << node_output_num_.size();
std::map<AnfNodePtr, size_t> none_zero_nodes;
if (node_input_edges_.size() != node_input_num_.size()) {
MS_LOG(EXCEPTION) << "node_input_edges_ size :" << node_input_edges_.size()
<< "not equal to node_input_num_ size:" << node_input_num_.size();
}
for (auto &it : node_output_num_) {
for (auto &it : node_input_num_) {
MS_EXCEPTION_IF_NULL(it.first);
string str;
auto node_output_it = node_output_edges_.find(it.first);
if (node_output_it == node_output_edges_.end()) {
auto node_input_it = node_input_edges_.find(it.first);
if (node_input_it == node_input_edges_.end()) {
MS_LOG(EXCEPTION) << "Can't find node [" << it.first->DebugString() << "]";
}
for (const auto &output_edge : node_output_edges_[it.first]) {
MS_EXCEPTION_IF_NULL(output_edge.first);
str = str.append(output_edge.first->DebugString()).append("|");
for (const auto &input_edge : node_input_edges_[it.first]) {
MS_EXCEPTION_IF_NULL(input_edge.first);
str = str.append(input_edge.first->DebugString()).append("|");
}
if (it.second != 0) {
MS_LOG(WARNING) << "Node:" << it.first->DebugString() << ",outputs:" << str << ",output num:" << it.second;
none_zero_output[it.first] = it.second;
MS_LOG(WARNING) << "Node:" << it.first->DebugString() << ",inputs:" << str << ",input num:" << it.second;
none_zero_nodes[it.first] = it.second;
}
}
// if don't consider control depend and loop exit,a exception will be throw
if (!none_zero_output.empty()) {
MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_output.size();
if (!none_zero_nodes.empty()) {
MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_nodes.size();
}
}
@ -346,12 +383,13 @@ void KernelGraph::AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input,
} else {
input_it->second.push_back(input_depend_edge);
}
// add the depend sum of node
auto depend_it = node_output_num_.find(input);
if (depend_it == node_output_num_.end()) {
node_output_num_[input] = 0;
// add node input depend num
auto depend_it = node_input_num_.find(node);
if (depend_it == node_input_num_.end()) {
node_input_num_[node] = depend_edge_num;
} else {
depend_it->second += depend_edge_num;
}
node_output_num_[input] += depend_edge_num;
}
std::vector<AnfNodePtr> KernelGraph::GetOutputNodes(const AnfNodePtr &node) {
@ -429,9 +467,9 @@ bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue<Anf
return true;
}
void KernelGraph::BfsToUpdateNodeOutput() {
void KernelGraph::UpdateNodeEdgeList(std::stack<AnfNodePtr> *seed_nodes) {
node_output_edges_.clear();
node_output_num_.clear();
node_input_num_.clear();
node_input_edges_.clear();
std::vector<AnfNodePtr> control_depends;
std::unordered_set<AnfNodePtr> visited_nodes;
@ -441,6 +479,11 @@ void KernelGraph::BfsToUpdateNodeOutput() {
auto node = que.front();
que.pop();
MS_EXCEPTION_IF_NULL(node);
if (node->isa<Parameter>() || node->isa<ValueNode>()) {
seed_nodes->push(node);
continue;
}
if (!node->isa<CNode>()) {
continue;
}
@ -454,10 +497,6 @@ void KernelGraph::BfsToUpdateNodeOutput() {
control_depends.push_back(input);
depend_edge_num = 0;
}
// the 2rd input of depend is no depend edge
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) && input == cnode->input(kDependAttachNodeIndex)) {
depend_edge_num = 0;
}
PushNoVisitedNode(input, &que, &visited_nodes);
AddDependEdge(node, input, depend_edge_num);
}

View File

@ -22,6 +22,7 @@
#include <utility>
#include <string>
#include <queue>
#include <stack>
#include <map>
#include <unordered_set>
#include "ir/func_graph.h"
@ -93,8 +94,8 @@ class KernelGraph : public FuncGraph {
private:
// remove value node form graph
bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node);
// BFS to update all nodes' output
void BfsToUpdateNodeOutput();
// update node edge list
void UpdateNodeEdgeList(std::stack<AnfNodePtr> *seed_nodes);
// add node depend edge by data edge or control depend
void AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num);
// handle control depend
@ -114,7 +115,7 @@ class KernelGraph : public FuncGraph {
std::unordered_map<tensor::TensorPtr, ValueNodePtr> tensor_to_value_node_map_;
// include all value nodes
std::unordered_set<ValueNodePtr> graph_value_nodes_;
std::unordered_map<AnfNodePtr, size_t> node_output_num_;
std::unordered_map<AnfNodePtr, size_t> node_input_num_;
std::unordered_map<AnfNodePtr, std::vector<std::pair<AnfNodePtr, size_t>>> node_input_edges_;
// record map between ref final output anf with index and ref origin input with index
std::map<AnfWithOutIndex, AnfWithOutIndex> ref_out_in_map_;

View File

@ -135,4 +135,5 @@ def test_LSTM():
for epoch in range(num_epochs):
loss = train_network(train_features, train_labels)
losses.append(loss)
print("loss:", loss.asnumpy())
assert(losses[-1].asnumpy() < 0.01)