clean code

This commit is contained in:
jiaorui 2022-07-19 17:28:20 +08:00
parent cf17b1134a
commit a4137191c4
8 changed files with 34 additions and 33 deletions

View File

@ -240,7 +240,7 @@ enum class KernelModType {
HostKernelMod,
};
enum KernelErrorCode { KRET_OK = 0, KRET_RESIZE_FAILED = 1, KRET_UNKNOWN_SHAPE = 2, KRET_UNKNOWN_OUT_SHAPE = 3 };
enum KernelErrorCode : int { KRET_OK = 0, KRET_RESIZE_FAILED = 1, KRET_UNKNOWN_SHAPE = 2, KRET_UNKNOWN_OUT_SHAPE = 3 };
class KernelMod {
public:

View File

@ -184,5 +184,24 @@ void ReshapeKernelMod::Execute(const std::vector<AddressPtr> &inputs, const std:
}
MS_LOG(INFO) << "Execute host ReshapeKernel End";
}
bool ReshapeKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
auto node = anf_node_.lock();
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (stream_ == nullptr) {
stream_ = stream_ptr;
}
try {
Execute(inputs, outputs);
} catch (const std::exception &e) {
MS_LOG(ERROR) << "ReshapeKernelMod Launch failed. node: " << cnode->fullname_with_scope() << ", Error message is "
<< e.what();
return false;
}
return true;
}
} // namespace kernel
} // namespace mindspore

View File

@ -26,23 +26,7 @@ class ReshapeKernelMod : public HostKernelMod {
ReshapeKernelMod() = default;
~ReshapeKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
auto node = anf_node_.lock();
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (stream_ == nullptr) {
stream_ = stream_ptr;
}
try {
Execute(inputs, outputs);
} catch (const std::exception &e) {
MS_LOG(ERROR) << "ReshapeKernelMod Launch failed. node: " << cnode->fullname_with_scope() << ", Error message is "
<< e.what();
return false;
}
return true;
}
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
private:
void Execute() const;

View File

@ -37,8 +37,8 @@ GetNextDynamic::GetNextDynamic() {}
GetNextDynamic::~GetNextDynamic() {}
bool GetNextDynamic::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
bool GetNextDynamic::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, void *) {
return true;
}
@ -53,9 +53,8 @@ bool GetNextDynamic::Init(const mindspore::AnfNodePtr &anf_node) {
return true;
}
int GetNextDynamic::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &others) {
int GetNextDynamic::Resize(const BaseOperatorPtr &, const std::vector<KernelTensorPtr> &,
const std::vector<KernelTensorPtr> &, const std::map<uint32_t, tensor::TensorPtr> &) {
auto data_kernel = anf_node_.lock();
bool ret = device::PopDataFromDataQueue(data_kernel);
if (!ret) {
@ -77,8 +76,8 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> GetNextDynamicDesc::GetKer
std::vector<TypeId> output_type;
for (size_t idx = 0; idx < output_num; ++idx) {
auto data_type = common::AnfAlgo::GetOutputInferDataType(kernel_node, idx);
output_type.push_back(data_type);
output_format.push_back(kOpFormat_DEFAULT);
output_type.emplace_back(data_type);
output_format.emplace_back(kOpFormat_DEFAULT);
}
auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
builder.SetOutputsFormat(output_format);

View File

@ -35,9 +35,9 @@ class GetNextDynamic : public RtKernel {
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &others = std::map<uint32_t, tensor::TensorPtr>());
const std::map<uint32_t, tensor::TensorPtr> &others = std::map<uint32_t, tensor::TensorPtr>()) override;
std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, uint32_t) {
const std::vector<AddressPtr> &, uint32_t) override {
std::vector<TaskInfoPtr> res;
return res;
}
@ -47,7 +47,7 @@ class GetNextDynamicDesc : public RtKerDesc {
public:
GetNextDynamicDesc();
~GetNextDynamicDesc() override;
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> GetKernelInfo(const CNodePtr &kernel_node = nullptr) override;
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> GetKernelInfo(const CNodePtr &kernel_node) override;
};
MS_REG_RTKERNEL_DESC(getnext, GetNextDynamicDesc);

View File

@ -20,7 +20,6 @@
#include <memory>
#include <string>
#include <vector>
#include <utility>
#include <map>
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_mod.h"
#include "plugin/device/ascend/kernel/tbe/tbe_utils.h"

View File

@ -123,7 +123,7 @@ bool FusionBuildTbeJsonCreator::CheckInput(const FusionScopeInfo &fusion_scope_i
void FusionBuildTbeJsonCreator::GenDataJson(const std::vector<AnfNodePtr> &compute_nodes,
const std::vector<nlohmann::json> &compute_json,
std::vector<nlohmann::json> *op_list_json,
const ANodeFusionDataTypeMap &spec_data_input) {
const ANodeFusionDataTypeMap &spec_data_input) const {
MS_EXCEPTION_IF_NULL(op_list_json);
MS_LOG(DEBUG) << "Start.";
std::vector<std::string> compute_nodes_fullname;
@ -280,7 +280,7 @@ bool FusionBuildTbeJsonCreator::GenOutputsJson(const AnfNodePtr &anf_node, nlohm
}
void FusionBuildTbeJsonCreator::GenReusedOutputDesc(const AnfNodePtr &anf_node, size_t index, size_t output_index,
nlohmann::json *output_desc, size_t out_size) {
nlohmann::json *output_desc, size_t out_size) const {
GenDesJsonCommon(output_desc);
std::string output_desc_name = anf_node->fullname_with_scope() + "_" + std::to_string(index);
(*output_desc)[kJName] = output_desc_name;

View File

@ -33,9 +33,9 @@ class FusionBuildTbeJsonCreator : public TbeJsonCreator {
bool GenOutputsJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json) override;
std::vector<size_t> GetDescOutputIndex(const std::vector<int64_t> &output_used_nums) const;
void GenReusedOutputDesc(const AnfNodePtr &anf_node, size_t index, size_t output_index, nlohmann::json *output_desc,
size_t out_size);
size_t out_size) const;
void GenDataJson(const std::vector<AnfNodePtr> &compute_nodes, const std::vector<nlohmann::json> &compute_json,
std::vector<nlohmann::json> *op_list_json, const ANodeFusionDataTypeMap &spec_data_input);
std::vector<nlohmann::json> *op_list_json, const ANodeFusionDataTypeMap &spec_data_input) const;
bool AttrsJsonPostProcessing(const AnfNodePtr &, const OpInfoPtr &, nlohmann::json *) override;
void GenOtherJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json) override;