Merge pull request !37985 from lyqlola/master
This commit is contained in:
i-robot 2022-07-18 03:29:01 +00:00 committed by Gitee
commit fb7c04f5db
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 35 additions and 37 deletions

View File

@ -19,6 +19,7 @@
#include <exception>
#include <string>
#include "backend/common/optimizer/common_backend_optimization.h"
#include "plugin/device/ascend/optimizer/ascend_backend_optimization.h"
#include "plugin/device/ascend/optimizer/ascend_comm_op_reuse.h"
#include "plugin/device/ascend/hal/hardware/ascend_utils.h"
#include "common/graph_kernel/adapter/graph_kernel_optimization.h"
#include "common/graph_kernel/adapter/expander.h"
@ -106,6 +107,7 @@ void AscendGraphOptimization::OptimizeGraphWithoutDeviceInfo(const KernelGraphPt
CheckControlFlowDynamicShape(graph);
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
CommOpReuse(graph);
if (context_ptr->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK)) {
HandleControlFlow(NOT_NULL(graph));
}
@ -174,6 +176,30 @@ void AscendGraphOptimization::PostOptimization(const KernelGraphPtr &graph) cons
MS_LOG(INFO) << "Status record: end post optimization. graph id: " << graph->graph_id();
}
void AscendGraphOptimization::CommOpReuse(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
auto max_comm_op_reuse_num_env = common::GetEnv("MS_COMM_COMPILER_OPT");
if (!graph->is_graph_run_mode() || max_comm_op_reuse_num_env.empty()) {
return;
}
MS_LOG(INFO) << "Status record: start comm op reuse. graph id: " << graph->graph_id();
const uint32_t max_comm_op_reuse_num = IntToUint(std::stoi(max_comm_op_reuse_num_env));
MS_LOG(INFO) << "MAX_COMM_OP_REUSE_NUM: " << max_comm_op_reuse_num;
opt::AscendCommOpReuse comm_io_reuse(graph, max_comm_op_reuse_num);
comm_io_reuse.Run();
MS_LOG(INFO) << "Status record: end comm op reuse. graph id: " << graph->graph_id();
#ifdef ENABLE_DUMP_IR
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
if (save_graphs) {
std::string file_name = "hwopt_comm_reuse_after_graph_" + std::to_string(graph->graph_id()) + ".ir";
DumpIR(file_name, graph);
}
#endif
}
void AscendGraphOptimization::HardWareOptimization(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Status record: start hardware optimize. graph id: " << graph->graph_id();

View File

@ -53,6 +53,7 @@ class AscendGraphOptimization {
void PostOptimization(const KernelGraphPtr &graph) const;
// Graph Optimized level-3 interface
void CommOpReuse(const KernelGraphPtr &graph);
void IRFusionOptimization(const KernelGraphPtr &graph);
void UpdateRefOutputMap(const KernelGraphPtr &graph);
void AddGraphToManager(const NotNull<KernelGraphPtr> graph, const NotNull<FuncGraphManagerPtr> manager,

View File

@ -25,7 +25,6 @@
namespace mindspore {
namespace opt {
namespace ascend {
namespace {
template <class T>
std::string VecToString(const std::vector<T> &vec) {
@ -234,9 +233,9 @@ void AscendCommOpReuse::AnalyseCommOpReuse() {
KernelGraphPtr AscendCommOpReuse::CreateCommSubGraph(const CNodePtr &comm_op) {
MS_EXCEPTION_IF_NULL(comm_op);
MS_EXCEPTION_IF_ZERO("input size of comm_op " + comm_op->DebugString(), comm_op->size());
MS_EXCEPTION_IF_NULL(create_new_kernel_graph_);
// create sub graph
auto graph = create_new_kernel_graph_();
auto graph = std::make_shared<session::KernelGraph>();
graph->set_graph_id(comm_subgraph_sum_++);
MS_EXCEPTION_IF_NULL(graph);
auto sub_graph_inputs = graph->MutableInputs();
MS_EXCEPTION_IF_NULL(sub_graph_inputs);
@ -300,6 +299,5 @@ void AscendCommOpReuse::ReplaceCommOpToCallNode() {
origin_comm_op->set_inputs(new_call_op_args);
}
}
} // namespace ascend
} // namespace opt
} // namespace mindspore

View File

@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ASCEND_COMM_OP_REUSE_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ASCEND_COMM_OP_REUSE_H_
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_ASCEND_COMM_OP_REUSE_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_ASCEND_COMM_OP_REUSE_H_
#include <functional>
#include <map>
@ -27,14 +27,10 @@
namespace mindspore {
namespace opt {
namespace ascend {
class AscendCommOpReuse {
public:
AscendCommOpReuse(const KernelGraphPtr &root_graph, std::function<KernelGraphPtr()> create_new_kernel_graph,
const uint32_t &max_comm_op_reuse_num)
: root_graph_(root_graph),
create_new_kernel_graph_(create_new_kernel_graph),
max_comm_op_reuse_num_(max_comm_op_reuse_num) {}
AscendCommOpReuse(const KernelGraphPtr &root_graph, const uint32_t &max_comm_op_reuse_num)
: root_graph_(root_graph), max_comm_op_reuse_num_(max_comm_op_reuse_num) {}
void Run();
private:
@ -50,12 +46,11 @@ class AscendCommOpReuse {
KernelGraphPtr root_graph_ = {};
std::vector<std::pair<CNodePtr, KernelGraphPtr>> all_comm_ops_ = {}; // use vector to keep order
std::map<CNodePtr, KernelGraphPtr> reused_comm_sub_graphs_ = {}; // origin comm op to reused comm sub graph
std::function<KernelGraphPtr()> create_new_kernel_graph_ = {};
const uint32_t max_comm_op_reuse_num_;
uint32_t comm_subgraph_sum_ = 50000;
uint32_t total_comm_op_reuse_num_ = 0;
};
} // namespace ascend
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ASCEND_COMM_OP_REUSE_H_
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_ASCEND_COMM_OP_REUSE_H_

View File

@ -32,9 +32,6 @@
#include "kernel/common_utils.h"
#include "profiler/device/profiling.h"
#include "backend/common/optimizer/helper.h"
#ifdef ENABLE_D
#include "plugin/device/ascend/optimizer/ascend_comm_op_reuse.h"
#endif
#include "base/base_ref_utils.h"
#include "include/common/debug/dump_proto.h"
#ifdef ENABLE_DEBUGGER
@ -725,25 +722,6 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
}
#endif
#ifdef ENABLE_D
// Add comm op reuse for Ascend backend
auto max_comm_op_reuse_num_env = common::GetEnv("MS_COMM_COMPILER_OPT");
if (device_context->GetDeviceType() == device::DeviceType::kAscend && !max_comm_op_reuse_num_env.empty()) {
const uint32_t max_comm_op_reuse_num = IntToUint(std::stoi(max_comm_op_reuse_num_env));
MS_LOG(INFO) << "MAX_COMM_OP_REUSE_NUM: " << max_comm_op_reuse_num;
opt::ascend::AscendCommOpReuse comm_io_reuse(
graph, [this]() { return this->session_->NewKernelGraph(); }, max_comm_op_reuse_num);
comm_io_reuse.Run();
#ifdef ENABLE_DUMP_IR
if (save_graphs) {
std::string file_name = "hwopt_comm_reuse_after_graph_" + std::to_string(graph->graph_id()) + ".ir";
DumpIR(file_name, graph);
}
#endif
}
#endif
// Execute optimization pass.
device_context->kernel_executor_->OptimizeGraph(graph);