forked from mindspore-Ecosystem/mindspore
!31348 Clean Code
Merge pull request !31348 from tanghuikang/clean_code
commit 35d1782023

@@ -20,6 +20,7 @@
 namespace mindspore {
 namespace opt {
+constexpr const int64_t kFusionGap = 2;
 bool AdjustDependForParallelOptimizerRecomputeAllGather::Run(const FuncGraphPtr &graph) {
   MS_EXCEPTION_IF_NULL(graph);
   mindspore::HashMap<int64_t, bool> forward_allgather_recompute_value_in_fusion_group;
@@ -81,12 +82,14 @@ void AdjustDependForParallelOptimizerRecomputeAllGather::IncreaseAllgatherFusion
   MS_LOG(WARNING) << "Increase the duplicated allgather fusion id";
   for (auto &adjust_node : parallel_optimizer_recompute_first_fusion_allgathers) {
     int64_t current_fusion_id = common::AnfAlgo::GetNodeAttr<int64_t>(adjust_node, kAttrFusion);
-    int64_t destination_fusion_id = current_fusion_id + unrecompute_max_fusion_id - recompute_min_fusion_id + 2;
+    int64_t destination_fusion_id =
+      current_fusion_id + unrecompute_max_fusion_id - recompute_min_fusion_id + kFusionGap;
     common::AnfAlgo::SetNodeAttr(kAttrFusion, MakeValue(destination_fusion_id), adjust_node);
   }
   for (auto &adjust_node : parallel_optimizer_recompute_allgathers) {
     int64_t current_fusion_id = common::AnfAlgo::GetNodeAttr<int64_t>(adjust_node, kAttrFusion);
-    int64_t destination_fusion_id = current_fusion_id + unrecompute_max_fusion_id - recompute_min_fusion_id + 2;
+    int64_t destination_fusion_id =
+      current_fusion_id + unrecompute_max_fusion_id - recompute_min_fusion_id + kFusionGap;
     common::AnfAlgo::SetNodeAttr(kAttrFusion, MakeValue(destination_fusion_id), adjust_node);
   }
 }
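
Reviewer note: the two hunks above are one pattern — the bare `2` becomes the named constant `kFusionGap` declared in the first hunk, so the intent of the offset (leave a gap between the unrecompute and recompute fusion-id ranges) survives later edits. A minimal, self-contained sketch of the same arithmetic; the free function and `main` are illustrative, not the real pass:

```cpp
#include <cstdint>
#include <iostream>

constexpr int64_t kFusionGap = 2;  // named constant instead of a bare "2"

// Shift a recompute fusion id past the unrecompute range, keeping
// kFusionGap ids of separation between the two groups (same formula as the diff).
int64_t ShiftFusionId(int64_t current_id, int64_t unrecompute_max, int64_t recompute_min) {
  return current_id + unrecompute_max - recompute_min + kFusionGap;
}

int main() {
  std::cout << ShiftFusionId(3, 10, 1) << "\n";  // prints 14
  return 0;
}
```
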
@@ -71,7 +71,7 @@ Status SomasSolverPre::AddContiguousInfoInMultiMaps(const vector<vector<size_t>>
   for (size_t i = 0; i < aux.size() - 1; i++) {
     auto index1 = aux[i];
     auto index2 = aux[i + 1];
-    if (CheckTensors(pTensors, index1, index2) == FAILED) {
+    if (CheckTensors(pTensors, SizeToUint(index1), SizeToUint(index2)) == FAILED) {
       return FAILED;
     }
     for (size_t sol = 0; sol < vecTensorsMap->size(); sol++) {
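
The SOMAS hunk makes the `size_t` → `uint32_t` narrowing explicit instead of implicit at the call site. Below is a simplified stand-in for a `SizeToUint`-style checked conversion; the real MindSpore helper may differ in how it reports overflow:

```cpp
#include <cstddef>
#include <cstdint>
#include <limits>
#include <stdexcept>

// Checked narrowing: fails loudly instead of silently truncating large indices.
inline uint32_t SizeToUint(size_t u) {
  if (u > static_cast<size_t>(std::numeric_limits<uint32_t>::max())) {
    throw std::runtime_error("size_t value too large to fit in uint32_t");
  }
  return static_cast<uint32_t>(u);
}
```
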
@@ -25,6 +25,7 @@
 #include "kernel/common_utils.h"
 namespace mindspore {
 namespace kernel {
+constexpr size_t kJsonSuffixLength = 5;
 namespace {
 bool CheckHash(const std::string &json_file, const std::string &bin_file, const nlohmann::json &js) {
   if (js.find("sha256") == js.end()) {
@@ -108,7 +109,7 @@ bool KernelPack::ReadFromJsonFile(const std::string &json_f, const std::string &
   }

   if (processor == kProcessorCuda) {
-    std::string bin_f = json_f.substr(0, json_f.length() - 5) + ".ptx";
+    std::string bin_f = json_f.substr(0, json_f.length() - kJsonSuffixLength) + ".ptx";
     std::ifstream kernelbin(bin_f);
     if (!kernelbin.is_open()) {
       MS_LOG(ERROR) << "read kernel ptx file error, please check kernelmeta.";
@@ -140,7 +141,7 @@ bool KernelPack::ReadFromJsonFile(const std::string &json_f, const std::string &
   }

   std::string binfile_suffix = js["binFileSuffix"];
-  std::string bin_f = json_f.substr(0, json_f.length() - 5) + binfile_suffix;
+  std::string bin_f = json_f.substr(0, json_f.length() - kJsonSuffixLength) + binfile_suffix;
   if (binfile_suffix == ".so") {
     // change "xx/xx.so" -> "xx/libxx.so"
     auto sp = bin_f.rfind('/');
@@ -234,7 +235,7 @@ bool KernelPack::LoadKernelMeta(const std::string &json_f) {
   }
   ParseKernelJson(js);

-  std::string bin_f = json_f.substr(0, json_f.length() - 5) + kernel_json_info_.bin_file_suffix;
+  std::string bin_f = json_f.substr(0, json_f.length() - kJsonSuffixLength) + kernel_json_info_.bin_file_suffix;
   if (kernel_json_info_.bin_file_suffix == ".so") {
     // change "xx/xx.so" -> "xx/libxx.so"
     auto sp = bin_f.rfind('/');
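
The three KernelPack hunks share one change: the literal `5` used to strip a trailing `".json"` becomes `kJsonSuffixLength`. A self-contained sketch of the same string surgery; `BinPathFromJson` is a hypothetical name, and like the original it assumes the input actually ends in `".json"`:

```cpp
#include <iostream>
#include <string>

constexpr size_t kJsonSuffixLength = 5;  // strlen(".json")

// "kernels/foo.json" + ".ptx" -> "kernels/foo.ptx"
std::string BinPathFromJson(const std::string &json_f, const std::string &suffix) {
  return json_f.substr(0, json_f.length() - kJsonSuffixLength) + suffix;
}

int main() {
  std::cout << BinPathFromJson("kernels/foo.json", ".ptx") << "\n";
  return 0;
}
```
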
@@ -187,7 +187,7 @@ class KernelMod {
   explicit KernelMod(const AnfNodePtr &anf_node_ptr) : anf_node_(anf_node_ptr) {}
   virtual ~KernelMod() = default;

-  bool Launch(const KernelLaunchInfo &kernel_launch_address, void *stream_ptr) {
+  bool LaunchKernel(const KernelLaunchInfo &kernel_launch_address, void *stream_ptr) {
     return Launch(kernel_launch_address.inputs_, kernel_launch_address.workspaces_, kernel_launch_address.outputs_,
                   stream_ptr);
   }
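
Why rename instead of keeping two `Launch` overloads: in C++, a derived class that overrides the virtual `Launch(inputs, workspace, outputs, stream)` hides every other base-class `Launch` overload, so the convenience overload stops resolving on the derived type. Giving the non-virtual wrapper its own name, `LaunchKernel`, sidesteps the hiding. A hedged sketch with trimmed-down types, not the real `KernelMod` interface:

```cpp
#include <vector>

struct KernelLaunchInfo {
  std::vector<void *> inputs_, workspaces_, outputs_;
};

class KernelMod {
 public:
  virtual ~KernelMod() = default;
  virtual bool Launch(const std::vector<void *> &inputs, const std::vector<void *> &workspace,
                      const std::vector<void *> &outputs, void *stream) = 0;
  // Distinct name: never hidden by a derived class's Launch override.
  bool LaunchKernel(const KernelLaunchInfo &info, void *stream) {
    return Launch(info.inputs_, info.workspaces_, info.outputs_, stream);
  }
};

class MyKernel : public KernelMod {
 public:
  bool Launch(const std::vector<void *> &, const std::vector<void *> &,
              const std::vector<void *> &, void *) override {
    return true;
  }
};

int main() {
  MyKernel k;
  KernelLaunchInfo info;
  // If the wrapper were still spelled Launch, this call would not compile
  // without a `using KernelMod::Launch;` in MyKernel.
  return k.LaunchKernel(info, nullptr) ? 0 : 1;
}
```
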
@@ -365,7 +365,7 @@ bool AscendKernelRuntime::Init() {
   return true;
 }

-bool AscendKernelRuntime::LoadData(const session::KernelGraph & /* graph */) {
+bool AscendKernelRuntime::LoadData(const session::KernelGraph &) {
 #ifdef ENABLE_DEBUGGER
   MS_LOG(INFO) << "Start load step";
   MS_EXCEPTION_IF_NULL(debugger_);
@@ -81,7 +81,7 @@ void HcclDynamicKernel::StaticShapeExecute() {
   MS_EXCEPTION_IF_NULL(kernel_mod);
   KernelLaunchInfo kernel_launch_info;
   KernelRuntime::GenLaunchArgs(*kernel_mod, cnode, &kernel_launch_info);
-  kernel_mod->Launch(kernel_launch_info, stream_);
+  kernel_mod->LaunchKernel(kernel_launch_info, stream_);
 }

 void HcclDynamicKernel::Execute() {
@@ -31,7 +31,7 @@ ProfilerTask::ProfilerTask(const ModelContext &model_context, const std::shared_
   stream_ = stream_list[stream_id];
 }

-ProfilerTask::~ProfilerTask() {}
+ProfilerTask::~ProfilerTask() { stream_ = nullptr; }

 void ProfilerTask::Distribute() {
   MS_LOG(INFO) << "ProfilerTask Distribute start.";
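
The destructor change resets the borrowed `stream_` pointer rather than leaving it dangling after destruction, a common clean-code-checker requirement for raw pointer members. Minimal illustration with placeholder types (not the real `ProfilerTask`):

```cpp
class ProfilerTaskLike {
 public:
  explicit ProfilerTaskLike(void *stream) : stream_(stream) {}
  ~ProfilerTaskLike() { stream_ = nullptr; }  // not deleted: the runtime owns the stream

 private:
  void *stream_;  // borrowed from the runtime's stream list
};
```
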
@@ -298,8 +298,6 @@ void AscendDeviceContext::Destroy() {
   graph_event_.clear();
   rank_id_ = 0;
   if (runtime_instance_) {
-    // TODO(lzlang): Destroy runtime instance after fully support MindRT, otherwise runtime will be destructed
-    // repeatedly. runtime_instance_->ReleaseDeviceRes();
     runtime_instance_ = nullptr;
   }
   AscendGraphOptimization::GetInstance().Reset();
@@ -40,17 +40,6 @@ static constexpr const char *kHcclAlgoOption = "HCCL_algorithm";
     return HcclResult::HCCL_E_RESERVED; \
   }

-#define CHECK_EXCUTION_MODE()                                                                     \
-  do {                                                                                            \
-    auto hccl_mode = GetCurrentHcclMode();                                                        \
-    if (hccl_mode != hccl_mode_) {                                                                \
-      MS_LOG(EXCEPTION) << "HCCL is initialized in " << GetHcclModeString(hccl_mode_)             \
-                        << " but current execution mode is " << GetHcclModeString(hccl_mode)      \
-                        << ". Please set the execution mode before HCCL init(), and then do not " \
-                           "change it in the subsequent script";                                  \
-    }                                                                                             \
-  } while (0)
-
 static std::map<std::string, std::string> GenHcclOptions(uint32_t device_id, std::string_view rank_id,
                                                          std::string_view rank_file) {
   auto env_deploy_mode = mindspore::common::GetEnv(kHcclDeployModeEnv);
@@ -159,6 +148,16 @@ HcclMode HcclAdapter::GetCurrentHcclMode() const {
   }
 }

+void HcclAdapter::CheckExcutionMode() const {
+  auto hccl_mode = GetCurrentHcclMode();
+  if (hccl_mode != hccl_mode_) {
+    MS_LOG(EXCEPTION) << "HCCL is initialized in " << GetHcclModeString(hccl_mode_) << " but current execution mode is "
+                      << GetHcclModeString(hccl_mode)
+                      << ". Please set the execution mode before HCCL init(), and then do not change it in the "
+                         "subsequent script";
+  }
+}
+
 std::string HcclAdapter::GetHcclModeString(HcclMode hccl_mode) {
   static std::map<HcclMode, std::string> kHcclModeString = {
     {HcclMode::kGraph, "GRAPH_MODE"},
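
These two hunks replace the `CHECK_EXCUTION_MODE()` macro with the `CheckExcutionMode()` const member function (the misspelled identifier is kept as-is from the source): the behavior is identical, but a function is type-checked, steppable in a debugger, and not textually pasted into every call site. A condensed sketch of the shape of the result, with the enum and message trimmed relative to the real adapter:

```cpp
#include <stdexcept>

enum class HcclMode { kGraph, kPynative };

class HcclAdapterLike {
 public:
  explicit HcclAdapterLike(HcclMode init_mode) : hccl_mode_(init_mode) {}

  // Member function instead of a do{...}while(0) macro: one definition,
  // ordinary scoping, and compiler-checked arguments.
  void CheckExcutionMode(HcclMode current_mode) const {
    if (current_mode != hccl_mode_) {
      throw std::runtime_error("HCCL was initialized in a different execution mode; "
                               "set the mode before init() and do not change it afterwards");
    }
  }

 private:
  HcclMode hccl_mode_;
};
```
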
@@ -307,14 +306,14 @@ std::string HcclAdapter::GetHcclType(const AnfNodePtr &node) {

 HcclResult HcclAdapter::HcclBroadcast(void *buf, uint64_t count, HcclDataType dataType, uint32_t root,
                                       aclrtStream stream) const {
-  CHECK_EXCUTION_MODE();
+  CheckExcutionMode();
   CHECK_SYMBOL_NULL(launch_hccl_broadcast_);
   return launch_hccl_broadcast_(buf, count, dataType, root, hccl_comm_, stream);
 }

 HcclResult HcclAdapter::HcclAllReduce(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType,
                                       HcclReduceOp op, aclrtStream stream, const std::string &group) const {
-  CHECK_EXCUTION_MODE();
+  CheckExcutionMode();
   CHECK_SYMBOL_NULL(launch_hccl_all_reduce_);
   auto hccl_comm = GetHcomm(group);
   MS_EXCEPTION_IF_NULL(hccl_comm);
@@ -323,7 +322,7 @@ HcclResult HcclAdapter::HcclAllReduce(void *send_buf, void *recv_buf, uint64_t c

 HcclResult HcclAdapter::HcclReduceScatter(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType,
                                           HcclReduceOp op, aclrtStream stream, const std::string &group) const {
-  CHECK_EXCUTION_MODE();
+  CheckExcutionMode();
   CHECK_SYMBOL_NULL(launch_hccl_reduce_scatter_);
   auto hccl_comm = GetHcomm(group);
   MS_EXCEPTION_IF_NULL(hccl_comm);
@@ -332,7 +331,7 @@ HcclResult HcclAdapter::HcclReduceScatter(void *send_buf, void *recv_buf, uint64

 HcclResult HcclAdapter::HcclAllGather(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType,
                                       aclrtStream stream, const std::string &group) const {
-  CHECK_EXCUTION_MODE();
+  CheckExcutionMode();
   CHECK_SYMBOL_NULL(launch_hccl_all_gather_);
   auto hccl_comm = GetHcomm(group);
   MS_EXCEPTION_IF_NULL(hccl_comm);
@@ -341,7 +340,7 @@ HcclResult HcclAdapter::HcclAllGather(void *send_buf, void *recv_buf, uint64_t c

 HcclResult HcclAdapter::HcclSend(void *send_buf, uint64_t count, HcclDataType dataType, uint32_t destRank,
                                  aclrtStream stream, const std::string &group) const {
-  CHECK_EXCUTION_MODE();
+  CheckExcutionMode();
   CHECK_SYMBOL_NULL(launch_hccl_send_);
   auto hccl_comm = GetHcomm(group);
   MS_EXCEPTION_IF_NULL(hccl_comm);
@@ -350,7 +349,7 @@ HcclResult HcclAdapter::HcclSend(void *send_buf, uint64_t count, HcclDataType da

 HcclResult HcclAdapter::HcclRecv(void *recv_buf, uint64_t count, HcclDataType dataType, uint32_t srcRank,
                                  aclrtStream stream, const std::string &group) const {
-  CHECK_EXCUTION_MODE();
+  CheckExcutionMode();
   CHECK_SYMBOL_NULL(launch_hccl_recv_);
   auto hccl_comm = GetHcomm(group);
   MS_EXCEPTION_IF_NULL(hccl_comm);
@@ -474,7 +473,7 @@ bool HcclAdapter::FinalizeHcclComm() {
 }

 HcclResult HcclAdapter::HcclCreateGroup(const std::string &group, uint32_t rank_num, uint32_t *rank_ids) const {
-  CHECK_EXCUTION_MODE();
+  CheckExcutionMode();
   CHECK_SYMBOL_NULL(hccl_create_group_);
   return hccl_create_group_(group.c_str(), rank_num, rank_ids);
 }
@@ -485,25 +484,25 @@ HcclResult HcclAdapter::HcclDestroyGroup(const std::string &group) const {
 }

 HcclResult HcclAdapter::HcclGetRankId(uint32_t *rank_id) const {
-  CHECK_EXCUTION_MODE();
+  CheckExcutionMode();
   CHECK_SYMBOL_NULL(single_op_hccl_get_rank_id_);
   return single_op_hccl_get_rank_id_(hccl_comm_, rank_id);
 }

 HcclResult HcclAdapter::HcclGetRankSize(uint32_t *rank_size) const {
-  CHECK_EXCUTION_MODE();
+  CheckExcutionMode();
   CHECK_SYMBOL_NULL(single_op_hccl_get_rank_size_);
   return single_op_hccl_get_rank_size_(hccl_comm_, rank_size);
 }

 HcclResult HcclAdapter::HcclGetRankId(const std::string &group, uint32_t *rank_id) const {
-  CHECK_EXCUTION_MODE();
+  CheckExcutionMode();
   CHECK_SYMBOL_NULL(hccl_get_rank_id_);
   return hccl_get_rank_id_(group.c_str(), rank_id);
 }

 HcclResult HcclAdapter::HcclGetRankSize(const std::string &group, uint32_t *rank_size) const {
-  CHECK_EXCUTION_MODE();
+  CheckExcutionMode();
   CHECK_SYMBOL_NULL(hccl_get_rank_size_);
   return hccl_get_rank_size_(group.c_str(), rank_size);
 }
@@ -537,13 +536,13 @@ bool HcclAdapter::FinalizeHcclExec() {
 }

 HcclResult HcclAdapter::HcclExecEnqueueOp(const ::HcomOperation &op_info, const HExecCallBack &callback) const {
-  CHECK_EXCUTION_MODE();
+  CheckExcutionMode();
   CHECK_SYMBOL_NULL(hccl_exec_enqueue_op_);
   return hccl_exec_enqueue_op_(op_info, callback);
 }

 HcclResult HcclAdapter::HcclExecAllToAllv(const ::HcomAllToAllVParams &params, const HExecCallBack &callback) const {
-  CHECK_EXCUTION_MODE();
+  CheckExcutionMode();
   CHECK_SYMBOL_NULL(hccl_exec_enqueue_all_to_all_v_);
   return hccl_exec_enqueue_all_to_all_v_(params, callback);
 }
@@ -106,6 +106,7 @@ class HcclAdapter {
   bool FinalizeHcclExec();

   HcclMode GetCurrentHcclMode() const;
+  void CheckExcutionMode() const;
   static std::string GetHcclModeString(HcclMode hccl_mode);

   void *plugin_handle_ = nullptr;
@@ -28,8 +28,8 @@ AssignKernel::AssignKernel() {}

 AssignKernel::~AssignKernel() {}

-bool AssignKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> & /* workspace */,
-                          const std::vector<AddressPtr> & /*outputs*/, void *stream_ptr) {
+bool AssignKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+                          const std::vector<AddressPtr> &, void *stream_ptr) {
   if (inputs.size() != 2) {
     MS_LOG(ERROR) << "inputs size is not two";
     return false;
@@ -43,8 +43,8 @@ bool LabelGotoKernel::Init(const AnfNodePtr &anf_node) {
   return true;
 }

-bool LabelGotoKernel::Launch(const std::vector<AddressPtr> & /*inputs*/, const std::vector<AddressPtr> & /*workspace*/,
-                             const std::vector<AddressPtr> & /*outputs*/, void * /*stream_ptr*/) {
+bool LabelGotoKernel::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
+                             const std::vector<AddressPtr> &, void *) {
   MS_LOG(INFO) << "LabelGotoKernel launch";
   return true;
 }
@@ -50,9 +50,8 @@ bool ProfilingKernelMod::Init(const AnfNodePtr &anf_node) {
   return true;
 }

-bool ProfilingKernelMod::Launch(const std::vector<AddressPtr> & /*inputs*/,
-                                const std::vector<AddressPtr> & /*workspace*/,
-                                const std::vector<AddressPtr> & /*outputs*/, void * /*stream_ptr*/) {
+bool ProfilingKernelMod::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
+                                const std::vector<AddressPtr> &, void *) {
   return true;
 }

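
The pattern in the three kernel hunks above: parameters that an override intentionally ignores are left unnamed instead of carrying `/*inputs*/`-style commented names. Both forms silence `-Wunused-parameter`; this change standardizes on the unnamed form. A tiny illustration (the function is hypothetical):

```cpp
#include <vector>

// workspace, outputs, and stream are deliberately unnamed: unused, warning-free.
bool LaunchNoop(const std::vector<void *> &inputs, const std::vector<void *> &,
                const std::vector<void *> &, void *) {
  return inputs.size() == 2;  // only the named parameter is inspected
}
```
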
@@ -213,7 +213,6 @@ bool FusionBuildTbeJsonCreator::GenInputsJson(const AnfNodePtr &anf_node, nlohma
     input_desc_list_tmp.emplace_back(optional_input_desc);
   }
   std::vector<nlohmann::json> input_desc_list;
-  // TODO(jjf): error when reordered op have input not in input_nodes.
   TbeAdapter::InputOrderPass<nlohmann::json>(cnode, input_desc_list_tmp, &input_desc_list);
   (*compute_json)[kJInputDesc] = input_desc_list;
   return true;
@@ -180,7 +180,7 @@ void TbeKernelSelect::GetBroadcastPatternKernelInfo(const OpInfo &op_info) {
 void TbeKernelSelect::GetReducePatternKernelInfo(const OpInfo &op_info) {
   auto reduce_selecter = TbeKernelReduceSelecter(cnode_ptr_);
   SupportFormat support_format;
-  reduce_selecter.GetShapeInfo(&support_format);
+  (void)reduce_selecter.GetShapeInfo(&support_format);
   (void)reduce_selecter.IsReduceSupport5HD(&support_format);
   (void)reduce_selecter.IsReduceSupportFracZ(&support_format);
   (void)reduce_selecter.IsReduceSupportC1HWNCoC0(&support_format);
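
The `(void)` prefix added to `GetShapeInfo` marks a deliberately discarded return value, matching the neighboring `IsReduceSupport*` calls. With a `[[nodiscard]]` return type, the cast to void is the idiomatic way to tell the compiler the discard is intentional. Illustrative snippet (the stub function is hypothetical):

```cpp
[[nodiscard]] bool GetShapeInfoStub() { return true; }

void Caller() {
  (void)GetShapeInfoStub();  // result intentionally unused; no [[nodiscard]] warning
}
```
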
@@ -24,8 +24,8 @@

 namespace mindspore {
 namespace opt {
-void BatchMatmulDropoutDoMaskV3FusionPass::MatchBatchMatmulDropoutDoMaskV3(
-    const CNodePtr &cnode, const session::KernelGraph & /* kernel_graph */, FusedNodeRecord *candidate_fusion) {
+void BatchMatmulDropoutDoMaskV3FusionPass::MatchBatchMatmulDropoutDoMaskV3(const CNodePtr &cnode,
+                                                                           FusedNodeRecord *candidate_fusion) {
   MS_EXCEPTION_IF_NULL(cnode);
   MS_EXCEPTION_IF_NULL(candidate_fusion);
   auto batch_matmul = cnode->input(1);
@@ -50,7 +50,7 @@ void BatchMatmulDropoutDoMaskV3FusionPass::MatchSingleFusionPattern(const sessio
     MS_EXCEPTION_IF_NULL(cnode);

     if (common::AnfAlgo::GetCNodeName(cnode) == kDropoutDoMaskV3OpName) {
-      MatchBatchMatmulDropoutDoMaskV3(cnode, kernel_graph, candidate_fusion);
+      MatchBatchMatmulDropoutDoMaskV3(cnode, candidate_fusion);
     }
   }
 }
@@ -35,8 +35,7 @@ class BatchMatmulDropoutDoMaskV3FusionPass : public FusionBasePass {
   void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;

  private:
-  void MatchBatchMatmulDropoutDoMaskV3(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
-                                       FusedNodeRecord *candidate_fusion);
+  void MatchBatchMatmulDropoutDoMaskV3(const CNodePtr &cnode, FusedNodeRecord *candidate_fusion);
 };
 }  // namespace opt
 }  // namespace mindspore
@@ -25,7 +25,6 @@
 namespace mindspore {
 namespace opt {
 void MatmulDropoutDoMaskV3AddFusionPass::MatchMatmulDropoutDoMaskV3Add(const CNodePtr &cnode,
-                                                                       const session::KernelGraph & /* kernel_graph */,
                                                                        FusedNodeRecord *candidate_fusion) {
   MS_EXCEPTION_IF_NULL(cnode);
   MS_EXCEPTION_IF_NULL(candidate_fusion);
@@ -59,7 +58,7 @@ void MatmulDropoutDoMaskV3AddFusionPass::MatchSingleFusionPattern(const session:
     MS_EXCEPTION_IF_NULL(cnode);

     if (common::AnfAlgo::GetCNodeName(cnode) == kAddOpName) {
-      MatchMatmulDropoutDoMaskV3Add(cnode, kernel_graph, candidate_fusion);
+      MatchMatmulDropoutDoMaskV3Add(cnode, candidate_fusion);
     }
   }
 }
@@ -35,8 +35,7 @@ class MatmulDropoutDoMaskV3AddFusionPass : public FusionBasePass {
   void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;

  private:
-  void MatchMatmulDropoutDoMaskV3Add(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
-                                     FusedNodeRecord *candidate_fusion);
+  void MatchMatmulDropoutDoMaskV3Add(const CNodePtr &cnode, FusedNodeRecord *candidate_fusion);
 };
 }  // namespace opt
 }  // namespace mindspore
@@ -24,7 +24,6 @@
 namespace mindspore {
 namespace opt {
 void MatmulEltwiseFusionPass::MatchMatmulEltwise(const CNodePtr &cnode, const AnfNodePtr &relu_input,
-                                                 const session::KernelGraph & /* kernel_graph */,
                                                  FusedNodeRecord *candidate_fusion) {
   MS_EXCEPTION_IF_NULL(cnode);
   MS_EXCEPTION_IF_NULL(candidate_fusion);
@@ -62,7 +61,7 @@ void MatmulEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGrap
     auto eltwise_input = cnode->input(kIndex1);
     MS_EXCEPTION_IF_NULL(eltwise_input);
     if (eltwise_input->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimMatMul)) {
-      MatchMatmulEltwise(cnode, eltwise_input, kernel_graph, candidate_fusion);
+      MatchMatmulEltwise(cnode, eltwise_input, candidate_fusion);
     }
   }
 }
@@ -37,8 +37,7 @@ class MatmulEltwiseFusionPass : public FusionBasePass {
   void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;

  private:
-  void MatchMatmulEltwise(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph,
-                          FusedNodeRecord *candidate_fusion);
+  void MatchMatmulEltwise(const CNodePtr &cnode, const AnfNodePtr &relu_input, FusedNodeRecord *candidate_fusion);
 };
 }  // namespace opt
 }  // namespace mindspore
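
The three fusion passes above all apply the same refactor: a private matcher loses a `kernel_graph` parameter that nothing reads, so declaration, definition, and the single call site change together. A skeleton of that shape with placeholder types (not the real pass classes):

```cpp
struct CNode {};
struct FusedNodeRecord {};

class FusionPassLike {
 public:
  void MatchSingleFusionPattern(const CNode &cnode, FusedNodeRecord *out) {
    Match(cnode, out);  // call site no longer threads the unused graph through
  }

 private:
  // before: void Match(const CNode &, const KernelGraph & /* kernel_graph */, FusedNodeRecord *);
  void Match(const CNode &, FusedNodeRecord *out) { (void)out; }
};
```
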
@@ -903,7 +903,7 @@ bool GPUKernelRuntime::RunOpLaunchKernelDynamic(const session::KernelGraph *grap
     KernelLaunchInfo kernel_launch_info;
     GenLaunchArgs(*kernel_mod, kernel, &kernel_launch_info);
     MS_EXCEPTION_IF_NULL(stream_);
-    auto ret = kernel_mod->Launch(kernel_launch_info, stream_);
+    auto ret = kernel_mod->LaunchKernel(kernel_launch_info, stream_);
     if (!ret) {
       MS_LOG(ERROR) << "Launch kernel failed.";
       return false;
@@ -1294,7 +1294,7 @@ bool KernelRuntime::LaunchKernelWithPynativeProfiling(kernel::KernelMod *kernel_
   start->set_record_stream(stream);
   end->set_record_stream(stream);
   start->RecordEvent();
-  bool ret = kernel_mod->Launch(kernel_launch_info, stream);
+  bool ret = kernel_mod->LaunchKernel(kernel_launch_info, stream);
   if (!ret) {
     MS_LOG(EXCEPTION) << "Launch kernel failed, kernel name is : " << op_name;
   }
@@ -1523,7 +1523,7 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNod
   if (pynative_mode_profiling_flag_) {
     ret = LaunchKernelWithPynativeProfiling(kernel_mod, kernel->fullname_with_scope(), kernel_launch_info, stream);
   } else {
-    ret = kernel_mod->Launch(kernel_launch_info, stream);
+    ret = kernel_mod->LaunchKernel(kernel_launch_info, stream);
   }
   if (!ret) {
     return ret;