!2860 stream analysis for memory reuse

Merge pull request !2860 from gukecai/stream-for-memcpy
This commit is contained in:
mindspore-ci-bot 2020-07-07 17:36:59 +08:00 committed by Gitee
commit 0a375743d1
2 changed files with 374 additions and 0 deletions

View File

@ -48,6 +48,12 @@ void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr)
CheckResourceAssign(graph_ptr);
MS_LOG(INFO) << "After finish stream assign";
FindStreamRelations(graph_ptr);
PrintStreamRelations();
GetStreamRelations();
PrintStreamGroups();
FindEventRelations(graph_ptr);
// Get info for D Model
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
generator::IRModelUtil::GetInstance().set_event_num(resource_manager.get_cur_event_num());
@ -501,6 +507,8 @@ void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull<KernelGraphPt
auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_stream_id);
cnodes.emplace_back(recv);
cnodes.emplace_back(cur_cnode_ptr);
} else {
cnodes.emplace_back(cur_cnode_ptr);
}
pre_stream_id = cur_stream_id;
}
@ -910,7 +918,351 @@ void AscendStreamAssign::Reset() {
common_stream_map_.clear();
processed_streams_.clear();
need_first_active_streams_.clear();
stream_groups_.clear();
stream_relations_.clear();
event_map_.clear();
}
// section 10
bool AscendStreamAssign::IsVecExist(std::vector<uint32_t> *group) {
auto group_size = group->size();
if (group_size == 0) {
return false;
}
for (const auto &item : stream_groups_) {
if (item.size() < group->size()) {
continue;
}
bool flag = true;
for (size_t i = 0; i < group_size; i++) {
if (item[i] != group->at(i)) {
flag = false;
break;
}
}
if (flag) {
return true;
} else {
continue;
}
}
return false;
}
void AscendStreamAssign::DFS(uint32_t start, std::vector<uint32_t> *group) {
auto it = stream_relations_.find(start);
if (it == stream_relations_.end()) {
if (!IsVecExist(group)) {
stream_groups_.emplace_back(*group);
} else {
MS_LOG(WARNING) << "DFS should not print this log";
}
return;
}
vector<uint32_t> active_streams = stream_relations_[start];
for (const auto &item : active_streams) {
group->emplace_back(item);
DFS(item, group);
group->pop_back();
}
}
void AscendStreamAssign::GetStreamRelations() {
for (const auto &start : need_first_active_streams_) {
vector<uint32_t> group{start};
DFS(start, &group);
}
}
void AscendStreamAssign::FindStreamRelations(const NotNull<KernelGraphPtr> &graph_ptr) {
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
auto stream_num = resource_manager.get_cur_stream_num();
if (stream_num <= 1) {
return;
}
auto exe_orders = graph_ptr->execution_order();
for (size_t i = 0; i < exe_orders.size(); i++) {
auto cur_cnode = exe_orders[i];
auto name = AnfAlgo::GetCNodeName(cur_cnode);
if (name != kStreamSwitchOpName && name != kStreamActiveOpName) {
continue;
}
// support:streamswitch is begin of the stream
if (name == kStreamSwitchOpName) {
GetStreamSwitchStreamRelation(cur_cnode);
}
if (name == kStreamActiveOpName) {
GetStreamActiveStreamRelation(graph_ptr, i);
}
}
}
void AscendStreamAssign::GetStreamSwitchStreamRelation(const CNodePtr &node_ptr) {
MS_EXCEPTION_IF_NULL(node_ptr);
auto cur_stream_id = AnfAlgo::GetStreamId(node_ptr);
auto true_stream_id = AnfAlgo::GetNodeAttr<uint32_t>(node_ptr, kAttrTrueBranchStream);
if (true_stream_id <= cur_stream_id) {
MS_LOG(ERROR) << "StreamSwitch self stream id " << cur_stream_id
<< " is greater than true branch stream id:" << true_stream_id;
}
auto it = stream_relations_.find(cur_stream_id);
if (it == stream_relations_.end()) {
stream_relations_[cur_stream_id] = {true_stream_id};
} else {
auto iter =
std::find(stream_relations_[cur_stream_id].begin(), stream_relations_[cur_stream_id].end(), true_stream_id);
if (iter == stream_relations_[cur_stream_id].end()) {
stream_relations_[cur_stream_id].emplace_back(true_stream_id);
}
}
}
void AscendStreamAssign::GetStreamActiveStreamRelation(const NotNull<KernelGraphPtr> &graph_ptr, size_t index) {
StreamActiveKind kind = GetStreamActiveKind(graph_ptr, index);
if (kind == kInvalid) {
MS_LOG(INFO) << "Invalid streamActive kind";
return;
}
auto orders = graph_ptr->execution_order();
auto cur_cnode = orders[index];
auto cur_stream_id = AnfAlgo::GetStreamId(cur_cnode);
auto active_list = AnfAlgo::GetNodeAttr<vector<uint32_t>>(cur_cnode, kAttrActiveStreamList);
if (kind == kHead) {
uint32_t active_current_node = GetStreamByActivedStream(cur_stream_id);
if (active_current_node == kInvalidStreamId) {
MS_LOG(EXCEPTION) << "No stream to active streamactive stream";
}
for (const auto &item : active_list) {
if (item <= active_current_node) {
MS_LOG(WARNING) << "Actived stream is less than activing stream";
continue;
}
auto it =
std::find(stream_relations_[active_current_node].begin(), stream_relations_[active_current_node].end(), item);
if (it == stream_relations_[active_current_node].end()) {
stream_relations_[active_current_node].emplace_back(item);
}
}
}
if (kind == kMiddle) {
for (const auto &stream : active_list) {
if (stream <= cur_stream_id) {
MS_LOG(INFO) << "MIDDLE StreamActive active stream is less than self stream, no need deal";
} else {
MS_LOG(ERROR) << "MIDDLE StreamActive active stream is greater than self stream, should not be exit now";
}
}
}
if (kind == kTail) {
auto it = stream_relations_.find(cur_stream_id);
if (it == stream_relations_.end()) {
stream_relations_[cur_stream_id] = active_list;
} else {
for (const auto &stream : active_list) {
if (stream <= cur_stream_id) {
MS_LOG(WARNING) << "Actived stream is less than activing stream";
continue;
}
auto iter = std::find(stream_relations_[cur_stream_id].begin(), stream_relations_[cur_stream_id].end(), stream);
if (iter == stream_relations_[cur_stream_id].end()) {
stream_relations_[cur_stream_id].emplace_back(stream);
}
}
}
}
}
StreamActiveKind AscendStreamAssign::GetStreamActiveKind(const NotNull<KernelGraphPtr> &graph_ptr, size_t index) {
auto exe_orders = graph_ptr->execution_order();
if (index >= exe_orders.size()) {
MS_LOG(EXCEPTION) << "Invalid op index:" << index;
}
auto cur_cnode = exe_orders[index];
auto cur_stream_id = AnfAlgo::GetStreamId(cur_cnode);
if (AnfAlgo::GetCNodeName(cur_cnode) != kStreamActiveOpName) {
MS_LOG(EXCEPTION) << "Current node name is not StreamActive";
}
if (index == 0) {
return kInvalid;
}
if (index == exe_orders.size() - 1) {
return kInvalid;
}
uint32_t pre_stream_id = UINT32_MAX;
uint32_t next_stream_id = UINT32_MAX;
int32_t start = SizeToInt(index);
for (int32_t i = start; i >= 0; i--) {
auto cnode = exe_orders[IntToSize(i)];
auto name = AnfAlgo::GetCNodeName(cnode);
if (name == kSendOpName || name == kRecvOpName) {
continue;
}
pre_stream_id = AnfAlgo::GetStreamId(cnode);
break;
}
for (size_t i = index + 1; i < exe_orders.size(); i++) {
auto cnode = exe_orders[i];
auto name = AnfAlgo::GetCNodeName(cnode);
if (name == kSendOpName || name == kRecvOpName) {
continue;
}
next_stream_id = AnfAlgo::GetStreamId(cnode);
break;
}
// pre_stream_id = UINT32_MAX:means no node active current StreamActive
// next_stream_id = UINT32_MAX:means current StreamActive active no node
if (pre_stream_id == UINT32_MAX || next_stream_id == UINT32_MAX) {
return kInvalid;
}
if (cur_stream_id == pre_stream_id && cur_stream_id == next_stream_id) {
return kMiddle;
}
if (cur_stream_id == pre_stream_id) {
return kTail;
}
if (cur_stream_id == next_stream_id) {
return kHead;
}
return kInvalid;
}
uint32_t AscendStreamAssign::GetStreamByActivedStream(uint32_t actived_stream_id) {
if (stream_relations_.empty()) {
return kInvalidStreamId;
}
for (const auto &item : stream_relations_) {
auto it = std::find(item.second.begin(), item.second.end(), actived_stream_id);
if (it != item.second.end()) {
return item.first;
}
}
return kInvalidStreamId;
}
void AscendStreamAssign::PrintStreamRelations() {
MS_LOG(INFO) << "Stream relations size:" << stream_relations_.size();
for (const auto &item : stream_relations_) {
MS_LOG(INFO) << "Stream:" << item.first;
for (const auto &stream : item.second) {
MS_LOG(INFO) << "--actived stream id:" << stream;
}
}
}
void AscendStreamAssign::PrintStreamGroups() {
MS_LOG(INFO) << "Stream group size:" << stream_groups_.size();
for (const auto &item : stream_groups_) {
MS_LOG(INFO) << "Group:";
for (const auto &stream : item) {
MS_LOG(INFO) << "Stream id:" << stream;
}
}
}
// section 11
bool AscendStreamAssign::IsSatisfiedEvent(uint32_t send_stream_id, uint32_t recv_stream_id) const {
size_t send_group = 0;
size_t recv_group = 0;
bool send_flag = true;
bool recv_flag = true;
for (size_t i = 0; i < stream_groups_.size(); i++) {
auto group = stream_groups_[i];
if (send_flag) {
auto it = std::find(group.begin(), group.end(), send_stream_id);
if (it != group.end()) {
send_group = i;
send_flag = false;
}
}
if (recv_flag) {
auto it = std::find(group.begin(), group.end(), recv_stream_id);
if (it != group.end()) {
recv_group = i;
recv_flag = false;
}
}
}
if (!(send_flag || recv_flag)) {
return (send_group != recv_group);
}
return false;
}
void AscendStreamAssign::FindEventRelations(const NotNull<KernelGraphPtr> &graph_ptr) {
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
auto event_nums = resource_manager.get_cur_event_num();
if (event_nums == 0) {
return;
}
auto exe_orders = graph_ptr->execution_order();
// find all event info
for (size_t i = 0; i < exe_orders.size(); i++) {
auto cur_cnode = exe_orders[i];
auto name = AnfAlgo::GetCNodeName(cur_cnode);
if (name == kSendOpName) {
event_map_[cur_cnode] = {};
}
if (name == kRecvOpName) {
auto recv_event_id = AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode, kAttrEventId);
for (auto &item : event_map_) {
auto send_event_id = AnfAlgo::GetNodeAttr<uint32_t>(item.first, kAttrEventId);
if (recv_event_id == send_event_id) {
item.second = cur_cnode;
break;
}
}
}
}
// delete useless event info
auto begin = event_map_.begin();
while (begin != event_map_.end()) {
auto send_stream_id = AnfAlgo::GetStreamId(begin->first);
auto recv_stream_id = AnfAlgo::GetStreamId(begin->second);
bool flag = IsSatisfiedEvent(send_stream_id, recv_stream_id);
if (!flag) {
begin = event_map_.erase(begin);
} else {
begin++;
}
}
MS_LOG(INFO) << "Satisfied event info";
for (const auto &item : event_map_) {
MS_LOG(INFO) << "Event_id:" << AnfAlgo::GetNodeAttr<uint32_t>(item.first, kAttrEventId);
}
}
} // namespace ascend
} // namespace device
} // namespace mindspore

View File

@ -94,6 +94,7 @@ class AscendResourceMng {
uint32_t cur_event_num_{0};
};
enum StreamActiveKind { kInvalid = 0, kHead, kMiddle, kTail };
class AscendStreamAssign {
public:
static AscendStreamAssign &GetInstance() {
@ -109,6 +110,8 @@ class AscendStreamAssign {
void GetWaitStreams(vector<uint32_t> *wait_active_stream_list);
CNodePtr CreateSendApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id, uint32_t stream_id);
CNodePtr CreateRecvApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id, uint32_t stream_id);
const std::vector<std::vector<uint32_t>> &get_stream_group() const { return stream_groups_; }
const std::map<CNodePtr, CNodePtr> &get_event_map() const { return event_map_; }
private:
AscendStreamAssign() = default;
@ -147,6 +150,20 @@ class AscendStreamAssign {
const CNodePtr &node);
void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t> *parallel_streams);
// function for memory resue
void GetStreamRelations();
void DFS(uint32_t start, std::vector<uint32_t> *group);
bool IsVecExist(std::vector<uint32_t> *group);
void FindStreamRelations(const NotNull<KernelGraphPtr> &graph_ptr);
void GetStreamSwitchStreamRelation(const CNodePtr &node_ptr);
void GetStreamActiveStreamRelation(const NotNull<KernelGraphPtr> &graph_ptr, size_t index);
StreamActiveKind GetStreamActiveKind(const NotNull<KernelGraphPtr> &graph_ptr, size_t index);
uint32_t GetStreamByActivedStream(uint32_t actived_stream_id);
void PrintStreamRelations();
void PrintStreamGroups();
void FindEventRelations(const NotNull<KernelGraphPtr> &graph_ptr);
bool IsSatisfiedEvent(uint32_t send_stream_id, uint32_t recv_stream_id) const;
bool independent_stream_activated_{false};
bool hcom_stream_activated_{false};
std::map<uint32_t, uint32_t> independent_stream_map_{};
@ -154,6 +171,11 @@ class AscendStreamAssign {
std::map<uint32_t, uint32_t> common_stream_map_{};
std::set<uint32_t> processed_streams_{};
std::vector<uint32_t> need_first_active_streams_{};
// attr for memory copy reuse
std::map<uint32_t, std::vector<uint32_t>> stream_relations_{};
std::vector<std::vector<uint32_t>> stream_groups_{};
std::map<CNodePtr, CNodePtr> event_map_;
// new policy end
};
} // namespace ascend