forked from mindspore-Ecosystem/mindspore
!2860 stream analysis for memory reuse
Merge pull request !2860 from gukecai/stream-for-memcpy
This commit is contained in:
commit
0a375743d1
|
@ -48,6 +48,12 @@ void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr)
|
|||
CheckResourceAssign(graph_ptr);
|
||||
MS_LOG(INFO) << "After finish stream assign";
|
||||
|
||||
FindStreamRelations(graph_ptr);
|
||||
PrintStreamRelations();
|
||||
GetStreamRelations();
|
||||
PrintStreamGroups();
|
||||
FindEventRelations(graph_ptr);
|
||||
|
||||
// Get info for D Model
|
||||
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
|
||||
generator::IRModelUtil::GetInstance().set_event_num(resource_manager.get_cur_event_num());
|
||||
|
@ -501,6 +507,8 @@ void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull<KernelGraphPt
|
|||
auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_stream_id);
|
||||
cnodes.emplace_back(recv);
|
||||
cnodes.emplace_back(cur_cnode_ptr);
|
||||
} else {
|
||||
cnodes.emplace_back(cur_cnode_ptr);
|
||||
}
|
||||
pre_stream_id = cur_stream_id;
|
||||
}
|
||||
|
@ -910,7 +918,351 @@ void AscendStreamAssign::Reset() {
|
|||
common_stream_map_.clear();
|
||||
processed_streams_.clear();
|
||||
need_first_active_streams_.clear();
|
||||
stream_groups_.clear();
|
||||
stream_relations_.clear();
|
||||
event_map_.clear();
|
||||
}
|
||||
|
||||
// section 10
|
||||
bool AscendStreamAssign::IsVecExist(std::vector<uint32_t> *group) {
|
||||
auto group_size = group->size();
|
||||
if (group_size == 0) {
|
||||
return false;
|
||||
}
|
||||
for (const auto &item : stream_groups_) {
|
||||
if (item.size() < group->size()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool flag = true;
|
||||
for (size_t i = 0; i < group_size; i++) {
|
||||
if (item[i] != group->at(i)) {
|
||||
flag = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (flag) {
|
||||
return true;
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void AscendStreamAssign::DFS(uint32_t start, std::vector<uint32_t> *group) {
|
||||
auto it = stream_relations_.find(start);
|
||||
if (it == stream_relations_.end()) {
|
||||
if (!IsVecExist(group)) {
|
||||
stream_groups_.emplace_back(*group);
|
||||
} else {
|
||||
MS_LOG(WARNING) << "DFS should not print this log";
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
vector<uint32_t> active_streams = stream_relations_[start];
|
||||
|
||||
for (const auto &item : active_streams) {
|
||||
group->emplace_back(item);
|
||||
DFS(item, group);
|
||||
group->pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
void AscendStreamAssign::GetStreamRelations() {
|
||||
for (const auto &start : need_first_active_streams_) {
|
||||
vector<uint32_t> group{start};
|
||||
DFS(start, &group);
|
||||
}
|
||||
}
|
||||
|
||||
void AscendStreamAssign::FindStreamRelations(const NotNull<KernelGraphPtr> &graph_ptr) {
|
||||
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
|
||||
auto stream_num = resource_manager.get_cur_stream_num();
|
||||
if (stream_num <= 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto exe_orders = graph_ptr->execution_order();
|
||||
for (size_t i = 0; i < exe_orders.size(); i++) {
|
||||
auto cur_cnode = exe_orders[i];
|
||||
auto name = AnfAlgo::GetCNodeName(cur_cnode);
|
||||
if (name != kStreamSwitchOpName && name != kStreamActiveOpName) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// support:streamswitch is begin of the stream
|
||||
if (name == kStreamSwitchOpName) {
|
||||
GetStreamSwitchStreamRelation(cur_cnode);
|
||||
}
|
||||
|
||||
if (name == kStreamActiveOpName) {
|
||||
GetStreamActiveStreamRelation(graph_ptr, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void AscendStreamAssign::GetStreamSwitchStreamRelation(const CNodePtr &node_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(node_ptr);
|
||||
auto cur_stream_id = AnfAlgo::GetStreamId(node_ptr);
|
||||
auto true_stream_id = AnfAlgo::GetNodeAttr<uint32_t>(node_ptr, kAttrTrueBranchStream);
|
||||
if (true_stream_id <= cur_stream_id) {
|
||||
MS_LOG(ERROR) << "StreamSwitch self stream id " << cur_stream_id
|
||||
<< " is greater than true branch stream id:" << true_stream_id;
|
||||
}
|
||||
auto it = stream_relations_.find(cur_stream_id);
|
||||
if (it == stream_relations_.end()) {
|
||||
stream_relations_[cur_stream_id] = {true_stream_id};
|
||||
} else {
|
||||
auto iter =
|
||||
std::find(stream_relations_[cur_stream_id].begin(), stream_relations_[cur_stream_id].end(), true_stream_id);
|
||||
if (iter == stream_relations_[cur_stream_id].end()) {
|
||||
stream_relations_[cur_stream_id].emplace_back(true_stream_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void AscendStreamAssign::GetStreamActiveStreamRelation(const NotNull<KernelGraphPtr> &graph_ptr, size_t index) {
|
||||
StreamActiveKind kind = GetStreamActiveKind(graph_ptr, index);
|
||||
if (kind == kInvalid) {
|
||||
MS_LOG(INFO) << "Invalid streamActive kind";
|
||||
return;
|
||||
}
|
||||
|
||||
auto orders = graph_ptr->execution_order();
|
||||
auto cur_cnode = orders[index];
|
||||
auto cur_stream_id = AnfAlgo::GetStreamId(cur_cnode);
|
||||
auto active_list = AnfAlgo::GetNodeAttr<vector<uint32_t>>(cur_cnode, kAttrActiveStreamList);
|
||||
if (kind == kHead) {
|
||||
uint32_t active_current_node = GetStreamByActivedStream(cur_stream_id);
|
||||
if (active_current_node == kInvalidStreamId) {
|
||||
MS_LOG(EXCEPTION) << "No stream to active streamactive stream";
|
||||
}
|
||||
|
||||
for (const auto &item : active_list) {
|
||||
if (item <= active_current_node) {
|
||||
MS_LOG(WARNING) << "Actived stream is less than activing stream";
|
||||
continue;
|
||||
}
|
||||
auto it =
|
||||
std::find(stream_relations_[active_current_node].begin(), stream_relations_[active_current_node].end(), item);
|
||||
if (it == stream_relations_[active_current_node].end()) {
|
||||
stream_relations_[active_current_node].emplace_back(item);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (kind == kMiddle) {
|
||||
for (const auto &stream : active_list) {
|
||||
if (stream <= cur_stream_id) {
|
||||
MS_LOG(INFO) << "MIDDLE StreamActive active stream is less than self stream, no need deal";
|
||||
} else {
|
||||
MS_LOG(ERROR) << "MIDDLE StreamActive active stream is greater than self stream, should not be exit now";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (kind == kTail) {
|
||||
auto it = stream_relations_.find(cur_stream_id);
|
||||
if (it == stream_relations_.end()) {
|
||||
stream_relations_[cur_stream_id] = active_list;
|
||||
} else {
|
||||
for (const auto &stream : active_list) {
|
||||
if (stream <= cur_stream_id) {
|
||||
MS_LOG(WARNING) << "Actived stream is less than activing stream";
|
||||
continue;
|
||||
}
|
||||
auto iter = std::find(stream_relations_[cur_stream_id].begin(), stream_relations_[cur_stream_id].end(), stream);
|
||||
if (iter == stream_relations_[cur_stream_id].end()) {
|
||||
stream_relations_[cur_stream_id].emplace_back(stream);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
StreamActiveKind AscendStreamAssign::GetStreamActiveKind(const NotNull<KernelGraphPtr> &graph_ptr, size_t index) {
|
||||
auto exe_orders = graph_ptr->execution_order();
|
||||
if (index >= exe_orders.size()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid op index:" << index;
|
||||
}
|
||||
|
||||
auto cur_cnode = exe_orders[index];
|
||||
auto cur_stream_id = AnfAlgo::GetStreamId(cur_cnode);
|
||||
if (AnfAlgo::GetCNodeName(cur_cnode) != kStreamActiveOpName) {
|
||||
MS_LOG(EXCEPTION) << "Current node name is not StreamActive";
|
||||
}
|
||||
|
||||
if (index == 0) {
|
||||
return kInvalid;
|
||||
}
|
||||
|
||||
if (index == exe_orders.size() - 1) {
|
||||
return kInvalid;
|
||||
}
|
||||
|
||||
uint32_t pre_stream_id = UINT32_MAX;
|
||||
uint32_t next_stream_id = UINT32_MAX;
|
||||
int32_t start = SizeToInt(index);
|
||||
for (int32_t i = start; i >= 0; i--) {
|
||||
auto cnode = exe_orders[IntToSize(i)];
|
||||
auto name = AnfAlgo::GetCNodeName(cnode);
|
||||
if (name == kSendOpName || name == kRecvOpName) {
|
||||
continue;
|
||||
}
|
||||
|
||||
pre_stream_id = AnfAlgo::GetStreamId(cnode);
|
||||
break;
|
||||
}
|
||||
|
||||
for (size_t i = index + 1; i < exe_orders.size(); i++) {
|
||||
auto cnode = exe_orders[i];
|
||||
auto name = AnfAlgo::GetCNodeName(cnode);
|
||||
if (name == kSendOpName || name == kRecvOpName) {
|
||||
continue;
|
||||
}
|
||||
|
||||
next_stream_id = AnfAlgo::GetStreamId(cnode);
|
||||
break;
|
||||
}
|
||||
|
||||
// pre_stream_id = UINT32_MAX:means no node active current StreamActive
|
||||
// next_stream_id = UINT32_MAX:means current StreamActive active no node
|
||||
if (pre_stream_id == UINT32_MAX || next_stream_id == UINT32_MAX) {
|
||||
return kInvalid;
|
||||
}
|
||||
|
||||
if (cur_stream_id == pre_stream_id && cur_stream_id == next_stream_id) {
|
||||
return kMiddle;
|
||||
}
|
||||
|
||||
if (cur_stream_id == pre_stream_id) {
|
||||
return kTail;
|
||||
}
|
||||
|
||||
if (cur_stream_id == next_stream_id) {
|
||||
return kHead;
|
||||
}
|
||||
|
||||
return kInvalid;
|
||||
}
|
||||
|
||||
uint32_t AscendStreamAssign::GetStreamByActivedStream(uint32_t actived_stream_id) {
|
||||
if (stream_relations_.empty()) {
|
||||
return kInvalidStreamId;
|
||||
}
|
||||
|
||||
for (const auto &item : stream_relations_) {
|
||||
auto it = std::find(item.second.begin(), item.second.end(), actived_stream_id);
|
||||
if (it != item.second.end()) {
|
||||
return item.first;
|
||||
}
|
||||
}
|
||||
|
||||
return kInvalidStreamId;
|
||||
}
|
||||
|
||||
void AscendStreamAssign::PrintStreamRelations() {
|
||||
MS_LOG(INFO) << "Stream relations size:" << stream_relations_.size();
|
||||
for (const auto &item : stream_relations_) {
|
||||
MS_LOG(INFO) << "Stream:" << item.first;
|
||||
for (const auto &stream : item.second) {
|
||||
MS_LOG(INFO) << "--actived stream id:" << stream;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void AscendStreamAssign::PrintStreamGroups() {
|
||||
MS_LOG(INFO) << "Stream group size:" << stream_groups_.size();
|
||||
for (const auto &item : stream_groups_) {
|
||||
MS_LOG(INFO) << "Group:";
|
||||
for (const auto &stream : item) {
|
||||
MS_LOG(INFO) << "Stream id:" << stream;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// section 11
|
||||
bool AscendStreamAssign::IsSatisfiedEvent(uint32_t send_stream_id, uint32_t recv_stream_id) const {
|
||||
size_t send_group = 0;
|
||||
size_t recv_group = 0;
|
||||
bool send_flag = true;
|
||||
bool recv_flag = true;
|
||||
for (size_t i = 0; i < stream_groups_.size(); i++) {
|
||||
auto group = stream_groups_[i];
|
||||
if (send_flag) {
|
||||
auto it = std::find(group.begin(), group.end(), send_stream_id);
|
||||
if (it != group.end()) {
|
||||
send_group = i;
|
||||
send_flag = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (recv_flag) {
|
||||
auto it = std::find(group.begin(), group.end(), recv_stream_id);
|
||||
if (it != group.end()) {
|
||||
recv_group = i;
|
||||
recv_flag = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!(send_flag || recv_flag)) {
|
||||
return (send_group != recv_group);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void AscendStreamAssign::FindEventRelations(const NotNull<KernelGraphPtr> &graph_ptr) {
|
||||
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
|
||||
auto event_nums = resource_manager.get_cur_event_num();
|
||||
if (event_nums == 0) {
|
||||
return;
|
||||
}
|
||||
auto exe_orders = graph_ptr->execution_order();
|
||||
// find all event info
|
||||
for (size_t i = 0; i < exe_orders.size(); i++) {
|
||||
auto cur_cnode = exe_orders[i];
|
||||
auto name = AnfAlgo::GetCNodeName(cur_cnode);
|
||||
if (name == kSendOpName) {
|
||||
event_map_[cur_cnode] = {};
|
||||
}
|
||||
|
||||
if (name == kRecvOpName) {
|
||||
auto recv_event_id = AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode, kAttrEventId);
|
||||
for (auto &item : event_map_) {
|
||||
auto send_event_id = AnfAlgo::GetNodeAttr<uint32_t>(item.first, kAttrEventId);
|
||||
if (recv_event_id == send_event_id) {
|
||||
item.second = cur_cnode;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// delete useless event info
|
||||
auto begin = event_map_.begin();
|
||||
while (begin != event_map_.end()) {
|
||||
auto send_stream_id = AnfAlgo::GetStreamId(begin->first);
|
||||
auto recv_stream_id = AnfAlgo::GetStreamId(begin->second);
|
||||
bool flag = IsSatisfiedEvent(send_stream_id, recv_stream_id);
|
||||
if (!flag) {
|
||||
begin = event_map_.erase(begin);
|
||||
} else {
|
||||
begin++;
|
||||
}
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Satisfied event info";
|
||||
for (const auto &item : event_map_) {
|
||||
MS_LOG(INFO) << "Event_id:" << AnfAlgo::GetNodeAttr<uint32_t>(item.first, kAttrEventId);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -94,6 +94,7 @@ class AscendResourceMng {
|
|||
uint32_t cur_event_num_{0};
|
||||
};
|
||||
|
||||
enum StreamActiveKind { kInvalid = 0, kHead, kMiddle, kTail };
|
||||
class AscendStreamAssign {
|
||||
public:
|
||||
static AscendStreamAssign &GetInstance() {
|
||||
|
@ -109,6 +110,8 @@ class AscendStreamAssign {
|
|||
void GetWaitStreams(vector<uint32_t> *wait_active_stream_list);
|
||||
CNodePtr CreateSendApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id, uint32_t stream_id);
|
||||
CNodePtr CreateRecvApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id, uint32_t stream_id);
|
||||
const std::vector<std::vector<uint32_t>> &get_stream_group() const { return stream_groups_; }
|
||||
const std::map<CNodePtr, CNodePtr> &get_event_map() const { return event_map_; }
|
||||
|
||||
private:
|
||||
AscendStreamAssign() = default;
|
||||
|
@ -147,6 +150,20 @@ class AscendStreamAssign {
|
|||
const CNodePtr &node);
|
||||
void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t> *parallel_streams);
|
||||
|
||||
// function for memory resue
|
||||
void GetStreamRelations();
|
||||
void DFS(uint32_t start, std::vector<uint32_t> *group);
|
||||
bool IsVecExist(std::vector<uint32_t> *group);
|
||||
void FindStreamRelations(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||
void GetStreamSwitchStreamRelation(const CNodePtr &node_ptr);
|
||||
void GetStreamActiveStreamRelation(const NotNull<KernelGraphPtr> &graph_ptr, size_t index);
|
||||
StreamActiveKind GetStreamActiveKind(const NotNull<KernelGraphPtr> &graph_ptr, size_t index);
|
||||
uint32_t GetStreamByActivedStream(uint32_t actived_stream_id);
|
||||
void PrintStreamRelations();
|
||||
void PrintStreamGroups();
|
||||
void FindEventRelations(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||
bool IsSatisfiedEvent(uint32_t send_stream_id, uint32_t recv_stream_id) const;
|
||||
|
||||
bool independent_stream_activated_{false};
|
||||
bool hcom_stream_activated_{false};
|
||||
std::map<uint32_t, uint32_t> independent_stream_map_{};
|
||||
|
@ -154,6 +171,11 @@ class AscendStreamAssign {
|
|||
std::map<uint32_t, uint32_t> common_stream_map_{};
|
||||
std::set<uint32_t> processed_streams_{};
|
||||
std::vector<uint32_t> need_first_active_streams_{};
|
||||
|
||||
// attr for memory copy reuse
|
||||
std::map<uint32_t, std::vector<uint32_t>> stream_relations_{};
|
||||
std::vector<std::vector<uint32_t>> stream_groups_{};
|
||||
std::map<CNodePtr, CNodePtr> event_map_;
|
||||
// new policy end
|
||||
};
|
||||
} // namespace ascend
|
||||
|
|
Loading…
Reference in New Issue