forked from mindspore-Ecosystem/mindspore
commit
f87983833d
|
@ -24,9 +24,8 @@ namespace device {
|
||||||
namespace memswap {
|
namespace memswap {
|
||||||
bool MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph, size_t swap_mem_size) {
|
bool MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph, size_t swap_mem_size) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||||
graph_manager_ = kernel_graph->manager();
|
|
||||||
MS_EXCEPTION_IF_NULL(graph_manager_);
|
|
||||||
execution_order_ = kernel_graph->execution_order();
|
execution_order_ = kernel_graph->execution_order();
|
||||||
|
kernel_graph_ = kernel_graph;
|
||||||
|
|
||||||
size_t kernel_index = 0;
|
size_t kernel_index = 0;
|
||||||
for (const auto &kernel : execution_order_) {
|
for (const auto &kernel : execution_order_) {
|
||||||
|
@ -177,7 +176,10 @@ bool MemSwapManager::IsCommunicationRelevantOp(const AnfNodePtr &kernel) const {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
NodeUsersMap &user_map = graph_manager_->node_users();
|
MS_EXCEPTION_IF_NULL(kernel_graph_);
|
||||||
|
const auto &graph_manager = kernel_graph_->manager();
|
||||||
|
MS_EXCEPTION_IF_NULL(graph_manager);
|
||||||
|
NodeUsersMap &user_map = graph_manager->node_users();
|
||||||
auto iter = user_map.find(kernel);
|
auto iter = user_map.find(kernel);
|
||||||
bool adjacent_with_communication_op = false;
|
bool adjacent_with_communication_op = false;
|
||||||
if (iter != user_map.end()) {
|
if (iter != user_map.end()) {
|
||||||
|
@ -190,7 +192,10 @@ bool MemSwapManager::IsCommunicationRelevantOp(const AnfNodePtr &kernel) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
void MemSwapManager::SaveUserKernelTopoOrder() {
|
void MemSwapManager::SaveUserKernelTopoOrder() {
|
||||||
NodeUsersMap &user_map = graph_manager_->node_users();
|
MS_EXCEPTION_IF_NULL(kernel_graph_);
|
||||||
|
const auto &graph_manager = kernel_graph_->manager();
|
||||||
|
MS_EXCEPTION_IF_NULL(graph_manager);
|
||||||
|
NodeUsersMap &user_map = graph_manager->node_users();
|
||||||
for (const auto &kernel : execution_order_) {
|
for (const auto &kernel : execution_order_) {
|
||||||
auto iter = user_map.find(kernel);
|
auto iter = user_map.find(kernel);
|
||||||
if (iter == user_map.end()) {
|
if (iter == user_map.end()) {
|
||||||
|
|
|
@ -156,7 +156,7 @@ class MemSwapManager {
|
||||||
size_t distance_decay_step_;
|
size_t distance_decay_step_;
|
||||||
|
|
||||||
MemCopyManagerPtr mem_copy_manager_{nullptr};
|
MemCopyManagerPtr mem_copy_manager_{nullptr};
|
||||||
FuncGraphManagerPtr graph_manager_{nullptr};
|
const mindspore::session::KernelGraph *kernel_graph_{nullptr};
|
||||||
bool mem_swap_initialized_{false};
|
bool mem_swap_initialized_{false};
|
||||||
bool swap_info_already_set_{false};
|
bool swap_info_already_set_{false};
|
||||||
bool trigger_swap_{false};
|
bool trigger_swap_{false};
|
||||||
|
|
|
@ -37,9 +37,11 @@ class MergeAddN : public AnfVisitor {
|
||||||
public:
|
public:
|
||||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
||||||
Reset();
|
Reset();
|
||||||
optimizer_ = optimizer;
|
mng_ = optimizer->resource()->manager();
|
||||||
is_outer_ = true;
|
is_outer_ = true;
|
||||||
AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node);
|
AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node);
|
||||||
|
// do not hold this manager
|
||||||
|
mng_ = nullptr;
|
||||||
if (!is_match_ || node->func_graph() == nullptr) {
|
if (!is_match_ || node->func_graph() == nullptr) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -104,8 +106,7 @@ class MergeAddN : public AnfVisitor {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_unique(const AnfNodePtr &node) {
|
bool is_unique(const AnfNodePtr &node) {
|
||||||
auto mng = optimizer_->resource()->manager();
|
auto &node_users = mng_->node_users();
|
||||||
auto &node_users = mng->node_users();
|
|
||||||
if (node_users.find(node) == node_users.end()) {
|
if (node_users.find(node) == node_users.end()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -124,7 +125,7 @@ class MergeAddN : public AnfVisitor {
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
OptimizerPtr optimizer_{nullptr};
|
FuncGraphManagerPtr mng_{nullptr};
|
||||||
std::vector<AnfNodePtr> Xs_{}, Ys_{}, args_{};
|
std::vector<AnfNodePtr> Xs_{}, Ys_{}, args_{};
|
||||||
bool is_inner_{false}, is_outer_{false}, is_match_{false};
|
bool is_inner_{false}, is_outer_{false}, is_match_{false};
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in New Issue