Merge pull request !31348 from tanghuikang/clean_code
This commit is contained in:
i-robot 2022-03-17 01:20:59 +00:00 committed by Gitee
commit 35d1782023
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
23 changed files with 56 additions and 61 deletions

View File

@ -20,6 +20,7 @@
namespace mindspore {
namespace opt {
constexpr const int64_t kFusionGap = 2;
bool AdjustDependForParallelOptimizerRecomputeAllGather::Run(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
mindspore::HashMap<int64_t, bool> forward_allgather_recompute_value_in_fusion_group;
@ -81,12 +82,14 @@ void AdjustDependForParallelOptimizerRecomputeAllGather::IncreaseAllgatherFusion
MS_LOG(WARNING) << "Increase the duplicated allgather fusion id";
for (auto &adjust_node : parallel_optimizer_recompute_first_fusion_allgathers) {
int64_t current_fusion_id = common::AnfAlgo::GetNodeAttr<int64_t>(adjust_node, kAttrFusion);
int64_t destination_fusion_id = current_fusion_id + unrecompute_max_fusion_id - recompute_min_fusion_id + 2;
int64_t destination_fusion_id =
current_fusion_id + unrecompute_max_fusion_id - recompute_min_fusion_id + kFusionGap;
common::AnfAlgo::SetNodeAttr(kAttrFusion, MakeValue(destination_fusion_id), adjust_node);
}
for (auto &adjust_node : parallel_optimizer_recompute_allgathers) {
int64_t current_fusion_id = common::AnfAlgo::GetNodeAttr<int64_t>(adjust_node, kAttrFusion);
int64_t destination_fusion_id = current_fusion_id + unrecompute_max_fusion_id - recompute_min_fusion_id + 2;
int64_t destination_fusion_id =
current_fusion_id + unrecompute_max_fusion_id - recompute_min_fusion_id + kFusionGap;
common::AnfAlgo::SetNodeAttr(kAttrFusion, MakeValue(destination_fusion_id), adjust_node);
}
}

View File

@ -71,7 +71,7 @@ Status SomasSolverPre::AddContiguousInfoInMultiMaps(const vector<vector<size_t>>
for (size_t i = 0; i < aux.size() - 1; i++) {
auto index1 = aux[i];
auto index2 = aux[i + 1];
if (CheckTensors(pTensors, index1, index2) == FAILED) {
if (CheckTensors(pTensors, SizeToUint(index1), SizeToUint(index2)) == FAILED) {
return FAILED;
}
for (size_t sol = 0; sol < vecTensorsMap->size(); sol++) {

View File

@ -25,6 +25,7 @@
#include "kernel/common_utils.h"
namespace mindspore {
namespace kernel {
constexpr size_t kJsonSuffixLength = 5;
namespace {
bool CheckHash(const std::string &json_file, const std::string &bin_file, const nlohmann::json &js) {
if (js.find("sha256") == js.end()) {
@ -108,7 +109,7 @@ bool KernelPack::ReadFromJsonFile(const std::string &json_f, const std::string &
}
if (processor == kProcessorCuda) {
std::string bin_f = json_f.substr(0, json_f.length() - 5) + ".ptx";
std::string bin_f = json_f.substr(0, json_f.length() - kJsonSuffixLength) + ".ptx";
std::ifstream kernelbin(bin_f);
if (!kernelbin.is_open()) {
MS_LOG(ERROR) << "read kernel ptx file error, please check kernelmeta.";
@ -140,7 +141,7 @@ bool KernelPack::ReadFromJsonFile(const std::string &json_f, const std::string &
}
std::string binfile_suffix = js["binFileSuffix"];
std::string bin_f = json_f.substr(0, json_f.length() - 5) + binfile_suffix;
std::string bin_f = json_f.substr(0, json_f.length() - kJsonSuffixLength) + binfile_suffix;
if (binfile_suffix == ".so") {
// change "xx/xx.so" -> "xx/libxx.so"
auto sp = bin_f.rfind('/');
@ -234,7 +235,7 @@ bool KernelPack::LoadKernelMeta(const std::string &json_f) {
}
ParseKernelJson(js);
std::string bin_f = json_f.substr(0, json_f.length() - 5) + kernel_json_info_.bin_file_suffix;
std::string bin_f = json_f.substr(0, json_f.length() - kJsonSuffixLength) + kernel_json_info_.bin_file_suffix;
if (kernel_json_info_.bin_file_suffix == ".so") {
// change "xx/xx.so" -> "xx/libxx.so"
auto sp = bin_f.rfind('/');

View File

@ -187,7 +187,7 @@ class KernelMod {
explicit KernelMod(const AnfNodePtr &anf_node_ptr) : anf_node_(anf_node_ptr) {}
virtual ~KernelMod() = default;
bool Launch(const KernelLaunchInfo &kernel_launch_address, void *stream_ptr) {
bool LaunchKernel(const KernelLaunchInfo &kernel_launch_address, void *stream_ptr) {
return Launch(kernel_launch_address.inputs_, kernel_launch_address.workspaces_, kernel_launch_address.outputs_,
stream_ptr);
}

View File

@ -365,7 +365,7 @@ bool AscendKernelRuntime::Init() {
return true;
}
bool AscendKernelRuntime::LoadData(const session::KernelGraph & /* graph */) {
bool AscendKernelRuntime::LoadData(const session::KernelGraph &) {
#ifdef ENABLE_DEBUGGER
MS_LOG(INFO) << "Start load step";
MS_EXCEPTION_IF_NULL(debugger_);

View File

@ -81,7 +81,7 @@ void HcclDynamicKernel::StaticShapeExecute() {
MS_EXCEPTION_IF_NULL(kernel_mod);
KernelLaunchInfo kernel_launch_info;
KernelRuntime::GenLaunchArgs(*kernel_mod, cnode, &kernel_launch_info);
kernel_mod->Launch(kernel_launch_info, stream_);
kernel_mod->LaunchKernel(kernel_launch_info, stream_);
}
void HcclDynamicKernel::Execute() {

View File

@ -31,7 +31,7 @@ ProfilerTask::ProfilerTask(const ModelContext &model_context, const std::shared_
stream_ = stream_list[stream_id];
}
ProfilerTask::~ProfilerTask() {}
ProfilerTask::~ProfilerTask() { stream_ = nullptr; }
void ProfilerTask::Distribute() {
MS_LOG(INFO) << "ProfilerTask Distribute start.";

View File

@ -298,8 +298,6 @@ void AscendDeviceContext::Destroy() {
graph_event_.clear();
rank_id_ = 0;
if (runtime_instance_) {
// TODO(lzlang): Destroy runtime instance after fully support MindRT, otherwise runtime will be destructed
// repeatedly. runtime_instance_->ReleaseDeviceRes();
runtime_instance_ = nullptr;
}
AscendGraphOptimization::GetInstance().Reset();

View File

@ -40,17 +40,6 @@ static constexpr const char *kHcclAlgoOption = "HCCL_algorithm";
return HcclResult::HCCL_E_RESERVED; \
}
#define CHECK_EXCUTION_MODE() \
do { \
auto hccl_mode = GetCurrentHcclMode(); \
if (hccl_mode != hccl_mode_) { \
MS_LOG(EXCEPTION) << "HCCL is initialized in " << GetHcclModeString(hccl_mode_) \
<< " but current execution mode is " << GetHcclModeString(hccl_mode) \
<< ". Please set the execution mode before HCCL init(), and then do not " \
"change it in the subsequent script"; \
} \
} while (0)
static std::map<std::string, std::string> GenHcclOptions(uint32_t device_id, std::string_view rank_id,
std::string_view rank_file) {
auto env_deploy_mode = mindspore::common::GetEnv(kHcclDeployModeEnv);
@ -159,6 +148,16 @@ HcclMode HcclAdapter::GetCurrentHcclMode() const {
}
}
void HcclAdapter::CheckExcutionMode() const {
auto hccl_mode = GetCurrentHcclMode();
if (hccl_mode != hccl_mode_) {
MS_LOG(EXCEPTION) << "HCCL is initialized in " << GetHcclModeString(hccl_mode_) << " but current execution mode is "
<< GetHcclModeString(hccl_mode)
<< ". Please set the execution mode before HCCL init(), and then do not change it in the "
"subsequent script";
}
}
std::string HcclAdapter::GetHcclModeString(HcclMode hccl_mode) {
static std::map<HcclMode, std::string> kHcclModeString = {
{HcclMode::kGraph, "GRAPH_MODE"},
@ -307,14 +306,14 @@ std::string HcclAdapter::GetHcclType(const AnfNodePtr &node) {
HcclResult HcclAdapter::HcclBroadcast(void *buf, uint64_t count, HcclDataType dataType, uint32_t root,
aclrtStream stream) const {
CHECK_EXCUTION_MODE();
CheckExcutionMode();
CHECK_SYMBOL_NULL(launch_hccl_broadcast_);
return launch_hccl_broadcast_(buf, count, dataType, root, hccl_comm_, stream);
}
HcclResult HcclAdapter::HcclAllReduce(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType,
HcclReduceOp op, aclrtStream stream, const std::string &group) const {
CHECK_EXCUTION_MODE();
CheckExcutionMode();
CHECK_SYMBOL_NULL(launch_hccl_all_reduce_);
auto hccl_comm = GetHcomm(group);
MS_EXCEPTION_IF_NULL(hccl_comm);
@ -323,7 +322,7 @@ HcclResult HcclAdapter::HcclAllReduce(void *send_buf, void *recv_buf, uint64_t c
HcclResult HcclAdapter::HcclReduceScatter(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType,
HcclReduceOp op, aclrtStream stream, const std::string &group) const {
CHECK_EXCUTION_MODE();
CheckExcutionMode();
CHECK_SYMBOL_NULL(launch_hccl_reduce_scatter_);
auto hccl_comm = GetHcomm(group);
MS_EXCEPTION_IF_NULL(hccl_comm);
@ -332,7 +331,7 @@ HcclResult HcclAdapter::HcclReduceScatter(void *send_buf, void *recv_buf, uint64
HcclResult HcclAdapter::HcclAllGather(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType,
aclrtStream stream, const std::string &group) const {
CHECK_EXCUTION_MODE();
CheckExcutionMode();
CHECK_SYMBOL_NULL(launch_hccl_all_gather_);
auto hccl_comm = GetHcomm(group);
MS_EXCEPTION_IF_NULL(hccl_comm);
@ -341,7 +340,7 @@ HcclResult HcclAdapter::HcclAllGather(void *send_buf, void *recv_buf, uint64_t c
HcclResult HcclAdapter::HcclSend(void *send_buf, uint64_t count, HcclDataType dataType, uint32_t destRank,
aclrtStream stream, const std::string &group) const {
CHECK_EXCUTION_MODE();
CheckExcutionMode();
CHECK_SYMBOL_NULL(launch_hccl_send_);
auto hccl_comm = GetHcomm(group);
MS_EXCEPTION_IF_NULL(hccl_comm);
@ -350,7 +349,7 @@ HcclResult HcclAdapter::HcclSend(void *send_buf, uint64_t count, HcclDataType da
HcclResult HcclAdapter::HcclRecv(void *recv_buf, uint64_t count, HcclDataType dataType, uint32_t srcRank,
aclrtStream stream, const std::string &group) const {
CHECK_EXCUTION_MODE();
CheckExcutionMode();
CHECK_SYMBOL_NULL(launch_hccl_recv_);
auto hccl_comm = GetHcomm(group);
MS_EXCEPTION_IF_NULL(hccl_comm);
@ -474,7 +473,7 @@ bool HcclAdapter::FinalizeHcclComm() {
}
HcclResult HcclAdapter::HcclCreateGroup(const std::string &group, uint32_t rank_num, uint32_t *rank_ids) const {
CHECK_EXCUTION_MODE();
CheckExcutionMode();
CHECK_SYMBOL_NULL(hccl_create_group_);
return hccl_create_group_(group.c_str(), rank_num, rank_ids);
}
@ -485,25 +484,25 @@ HcclResult HcclAdapter::HcclDestroyGroup(const std::string &group) const {
}
HcclResult HcclAdapter::HcclGetRankId(uint32_t *rank_id) const {
CHECK_EXCUTION_MODE();
CheckExcutionMode();
CHECK_SYMBOL_NULL(single_op_hccl_get_rank_id_);
return single_op_hccl_get_rank_id_(hccl_comm_, rank_id);
}
HcclResult HcclAdapter::HcclGetRankSize(uint32_t *rank_size) const {
CHECK_EXCUTION_MODE();
CheckExcutionMode();
CHECK_SYMBOL_NULL(single_op_hccl_get_rank_size_);
return single_op_hccl_get_rank_size_(hccl_comm_, rank_size);
}
HcclResult HcclAdapter::HcclGetRankId(const std::string &group, uint32_t *rank_id) const {
CHECK_EXCUTION_MODE();
CheckExcutionMode();
CHECK_SYMBOL_NULL(hccl_get_rank_id_);
return hccl_get_rank_id_(group.c_str(), rank_id);
}
HcclResult HcclAdapter::HcclGetRankSize(const std::string &group, uint32_t *rank_size) const {
CHECK_EXCUTION_MODE();
CheckExcutionMode();
CHECK_SYMBOL_NULL(hccl_get_rank_size_);
return hccl_get_rank_size_(group.c_str(), rank_size);
}
@ -537,13 +536,13 @@ bool HcclAdapter::FinalizeHcclExec() {
}
HcclResult HcclAdapter::HcclExecEnqueueOp(const ::HcomOperation &op_info, const HExecCallBack &callback) const {
CHECK_EXCUTION_MODE();
CheckExcutionMode();
CHECK_SYMBOL_NULL(hccl_exec_enqueue_op_);
return hccl_exec_enqueue_op_(op_info, callback);
}
HcclResult HcclAdapter::HcclExecAllToAllv(const ::HcomAllToAllVParams &params, const HExecCallBack &callback) const {
CHECK_EXCUTION_MODE();
CheckExcutionMode();
CHECK_SYMBOL_NULL(hccl_exec_enqueue_all_to_all_v_);
return hccl_exec_enqueue_all_to_all_v_(params, callback);
}

View File

@ -106,6 +106,7 @@ class HcclAdapter {
bool FinalizeHcclExec();
HcclMode GetCurrentHcclMode() const;
void CheckExcutionMode() const;
static std::string GetHcclModeString(HcclMode hccl_mode);
void *plugin_handle_ = nullptr;

View File

@ -28,8 +28,8 @@ AssignKernel::AssignKernel() {}
AssignKernel::~AssignKernel() {}
bool AssignKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> & /* workspace */,
const std::vector<AddressPtr> & /*outputs*/, void *stream_ptr) {
bool AssignKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, void *stream_ptr) {
if (inputs.size() != 2) {
MS_LOG(ERROR) << "inputs size is not two";
return false;

View File

@ -43,8 +43,8 @@ bool LabelGotoKernel::Init(const AnfNodePtr &anf_node) {
return true;
}
bool LabelGotoKernel::Launch(const std::vector<AddressPtr> & /*inputs*/, const std::vector<AddressPtr> & /*workspace*/,
const std::vector<AddressPtr> & /*outputs*/, void * /*stream_ptr*/) {
bool LabelGotoKernel::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, void *) {
MS_LOG(INFO) << "LabelGotoKernel launch";
return true;
}

View File

@ -50,9 +50,8 @@ bool ProfilingKernelMod::Init(const AnfNodePtr &anf_node) {
return true;
}
bool ProfilingKernelMod::Launch(const std::vector<AddressPtr> & /*inputs*/,
const std::vector<AddressPtr> & /*workspace*/,
const std::vector<AddressPtr> & /*outputs*/, void * /*stream_ptr*/) {
bool ProfilingKernelMod::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, void *) {
return true;
}

View File

@ -213,7 +213,6 @@ bool FusionBuildTbeJsonCreator::GenInputsJson(const AnfNodePtr &anf_node, nlohma
input_desc_list_tmp.emplace_back(optional_input_desc);
}
std::vector<nlohmann::json> input_desc_list;
// TODO(jjf): error when reordered op have input not in input_nodes.
TbeAdapter::InputOrderPass<nlohmann::json>(cnode, input_desc_list_tmp, &input_desc_list);
(*compute_json)[kJInputDesc] = input_desc_list;
return true;

View File

@ -180,7 +180,7 @@ void TbeKernelSelect::GetBroadcastPatternKernelInfo(const OpInfo &op_info) {
void TbeKernelSelect::GetReducePatternKernelInfo(const OpInfo &op_info) {
auto reduce_selecter = TbeKernelReduceSelecter(cnode_ptr_);
SupportFormat support_format;
reduce_selecter.GetShapeInfo(&support_format);
(void)reduce_selecter.GetShapeInfo(&support_format);
(void)reduce_selecter.IsReduceSupport5HD(&support_format);
(void)reduce_selecter.IsReduceSupportFracZ(&support_format);
(void)reduce_selecter.IsReduceSupportC1HWNCoC0(&support_format);

View File

@ -24,8 +24,8 @@
namespace mindspore {
namespace opt {
void BatchMatmulDropoutDoMaskV3FusionPass::MatchBatchMatmulDropoutDoMaskV3(
const CNodePtr &cnode, const session::KernelGraph & /* kernel_graph */, FusedNodeRecord *candidate_fusion) {
void BatchMatmulDropoutDoMaskV3FusionPass::MatchBatchMatmulDropoutDoMaskV3(const CNodePtr &cnode,
FusedNodeRecord *candidate_fusion) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(candidate_fusion);
auto batch_matmul = cnode->input(1);
@ -50,7 +50,7 @@ void BatchMatmulDropoutDoMaskV3FusionPass::MatchSingleFusionPattern(const sessio
MS_EXCEPTION_IF_NULL(cnode);
if (common::AnfAlgo::GetCNodeName(cnode) == kDropoutDoMaskV3OpName) {
MatchBatchMatmulDropoutDoMaskV3(cnode, kernel_graph, candidate_fusion);
MatchBatchMatmulDropoutDoMaskV3(cnode, candidate_fusion);
}
}
}

View File

@ -35,8 +35,7 @@ class BatchMatmulDropoutDoMaskV3FusionPass : public FusionBasePass {
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
private:
void MatchBatchMatmulDropoutDoMaskV3(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion);
void MatchBatchMatmulDropoutDoMaskV3(const CNodePtr &cnode, FusedNodeRecord *candidate_fusion);
};
} // namespace opt
} // namespace mindspore

View File

@ -25,7 +25,6 @@
namespace mindspore {
namespace opt {
void MatmulDropoutDoMaskV3AddFusionPass::MatchMatmulDropoutDoMaskV3Add(const CNodePtr &cnode,
const session::KernelGraph & /* kernel_graph */,
FusedNodeRecord *candidate_fusion) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(candidate_fusion);
@ -59,7 +58,7 @@ void MatmulDropoutDoMaskV3AddFusionPass::MatchSingleFusionPattern(const session:
MS_EXCEPTION_IF_NULL(cnode);
if (common::AnfAlgo::GetCNodeName(cnode) == kAddOpName) {
MatchMatmulDropoutDoMaskV3Add(cnode, kernel_graph, candidate_fusion);
MatchMatmulDropoutDoMaskV3Add(cnode, candidate_fusion);
}
}
}

View File

@ -35,8 +35,7 @@ class MatmulDropoutDoMaskV3AddFusionPass : public FusionBasePass {
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
private:
void MatchMatmulDropoutDoMaskV3Add(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion);
void MatchMatmulDropoutDoMaskV3Add(const CNodePtr &cnode, FusedNodeRecord *candidate_fusion);
};
} // namespace opt
} // namespace mindspore

View File

@ -24,7 +24,6 @@
namespace mindspore {
namespace opt {
void MatmulEltwiseFusionPass::MatchMatmulEltwise(const CNodePtr &cnode, const AnfNodePtr &relu_input,
const session::KernelGraph & /* kernel_graph */,
FusedNodeRecord *candidate_fusion) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(candidate_fusion);
@ -62,7 +61,7 @@ void MatmulEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGrap
auto eltwise_input = cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(eltwise_input);
if (eltwise_input->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimMatMul)) {
MatchMatmulEltwise(cnode, eltwise_input, kernel_graph, candidate_fusion);
MatchMatmulEltwise(cnode, eltwise_input, candidate_fusion);
}
}
}

View File

@ -37,8 +37,7 @@ class MatmulEltwiseFusionPass : public FusionBasePass {
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
private:
void MatchMatmulEltwise(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion);
void MatchMatmulEltwise(const CNodePtr &cnode, const AnfNodePtr &relu_input, FusedNodeRecord *candidate_fusion);
};
} // namespace opt
} // namespace mindspore

View File

@ -903,7 +903,7 @@ bool GPUKernelRuntime::RunOpLaunchKernelDynamic(const session::KernelGraph *grap
KernelLaunchInfo kernel_launch_info;
GenLaunchArgs(*kernel_mod, kernel, &kernel_launch_info);
MS_EXCEPTION_IF_NULL(stream_);
auto ret = kernel_mod->Launch(kernel_launch_info, stream_);
auto ret = kernel_mod->LaunchKernel(kernel_launch_info, stream_);
if (!ret) {
MS_LOG(ERROR) << "Launch kernel failed.";
return false;

View File

@ -1294,7 +1294,7 @@ bool KernelRuntime::LaunchKernelWithPynativeProfiling(kernel::KernelMod *kernel_
start->set_record_stream(stream);
end->set_record_stream(stream);
start->RecordEvent();
bool ret = kernel_mod->Launch(kernel_launch_info, stream);
bool ret = kernel_mod->LaunchKernel(kernel_launch_info, stream);
if (!ret) {
MS_LOG(EXCEPTION) << "Launch kernel failed, kernel name is : " << op_name;
}
@ -1523,7 +1523,7 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNod
if (pynative_mode_profiling_flag_) {
ret = LaunchKernelWithPynativeProfiling(kernel_mod, kernel->fullname_with_scope(), kernel_launch_info, stream);
} else {
ret = kernel_mod->Launch(kernel_launch_info, stream);
ret = kernel_mod->LaunchKernel(kernel_launch_info, stream);
}
if (!ret) {
return ret;