clean code
This commit is contained in:
parent
cf17b1134a
commit
a4137191c4
|
@ -240,7 +240,7 @@ enum class KernelModType {
|
|||
HostKernelMod,
|
||||
};
|
||||
|
||||
enum KernelErrorCode { KRET_OK = 0, KRET_RESIZE_FAILED = 1, KRET_UNKNOWN_SHAPE = 2, KRET_UNKNOWN_OUT_SHAPE = 3 };
|
||||
enum KernelErrorCode : int { KRET_OK = 0, KRET_RESIZE_FAILED = 1, KRET_UNKNOWN_SHAPE = 2, KRET_UNKNOWN_OUT_SHAPE = 3 };
|
||||
|
||||
class KernelMod {
|
||||
public:
|
||||
|
|
|
@ -184,5 +184,24 @@ void ReshapeKernelMod::Execute(const std::vector<AddressPtr> &inputs, const std:
|
|||
}
|
||||
MS_LOG(INFO) << "Execute host ReshapeKernel End";
|
||||
}
|
||||
|
||||
bool ReshapeKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||
auto node = anf_node_.lock();
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (stream_ == nullptr) {
|
||||
stream_ = stream_ptr;
|
||||
}
|
||||
try {
|
||||
Execute(inputs, outputs);
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(ERROR) << "ReshapeKernelMod Launch failed. node: " << cnode->fullname_with_scope() << ", Error message is "
|
||||
<< e.what();
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -26,23 +26,7 @@ class ReshapeKernelMod : public HostKernelMod {
|
|||
ReshapeKernelMod() = default;
|
||||
~ReshapeKernelMod() override = default;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
auto node = anf_node_.lock();
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (stream_ == nullptr) {
|
||||
stream_ = stream_ptr;
|
||||
}
|
||||
try {
|
||||
Execute(inputs, outputs);
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(ERROR) << "ReshapeKernelMod Launch failed. node: " << cnode->fullname_with_scope() << ", Error message is "
|
||||
<< e.what();
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
|
||||
|
||||
private:
|
||||
void Execute() const;
|
||||
|
|
|
@ -37,8 +37,8 @@ GetNextDynamic::GetNextDynamic() {}
|
|||
|
||||
GetNextDynamic::~GetNextDynamic() {}
|
||||
|
||||
bool GetNextDynamic::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||
bool GetNextDynamic::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &, void *) {
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -53,9 +53,8 @@ bool GetNextDynamic::Init(const mindspore::AnfNodePtr &anf_node) {
|
|||
return true;
|
||||
}
|
||||
|
||||
int GetNextDynamic::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &others) {
|
||||
int GetNextDynamic::Resize(const BaseOperatorPtr &, const std::vector<KernelTensorPtr> &,
|
||||
const std::vector<KernelTensorPtr> &, const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
auto data_kernel = anf_node_.lock();
|
||||
bool ret = device::PopDataFromDataQueue(data_kernel);
|
||||
if (!ret) {
|
||||
|
@ -77,8 +76,8 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> GetNextDynamicDesc::GetKer
|
|||
std::vector<TypeId> output_type;
|
||||
for (size_t idx = 0; idx < output_num; ++idx) {
|
||||
auto data_type = common::AnfAlgo::GetOutputInferDataType(kernel_node, idx);
|
||||
output_type.push_back(data_type);
|
||||
output_format.push_back(kOpFormat_DEFAULT);
|
||||
output_type.emplace_back(data_type);
|
||||
output_format.emplace_back(kOpFormat_DEFAULT);
|
||||
}
|
||||
auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
|
||||
builder.SetOutputsFormat(output_format);
|
||||
|
|
|
@ -35,9 +35,9 @@ class GetNextDynamic : public RtKernel {
|
|||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &others = std::map<uint32_t, tensor::TensorPtr>());
|
||||
const std::map<uint32_t, tensor::TensorPtr> &others = std::map<uint32_t, tensor::TensorPtr>()) override;
|
||||
std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &, uint32_t) {
|
||||
const std::vector<AddressPtr> &, uint32_t) override {
|
||||
std::vector<TaskInfoPtr> res;
|
||||
return res;
|
||||
}
|
||||
|
@ -47,7 +47,7 @@ class GetNextDynamicDesc : public RtKerDesc {
|
|||
public:
|
||||
GetNextDynamicDesc();
|
||||
~GetNextDynamicDesc() override;
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> GetKernelInfo(const CNodePtr &kernel_node = nullptr) override;
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> GetKernelInfo(const CNodePtr &kernel_node) override;
|
||||
};
|
||||
|
||||
MS_REG_RTKERNEL_DESC(getnext, GetNextDynamicDesc);
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <map>
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_mod.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_utils.h"
|
||||
|
|
|
@ -123,7 +123,7 @@ bool FusionBuildTbeJsonCreator::CheckInput(const FusionScopeInfo &fusion_scope_i
|
|||
void FusionBuildTbeJsonCreator::GenDataJson(const std::vector<AnfNodePtr> &compute_nodes,
|
||||
const std::vector<nlohmann::json> &compute_json,
|
||||
std::vector<nlohmann::json> *op_list_json,
|
||||
const ANodeFusionDataTypeMap &spec_data_input) {
|
||||
const ANodeFusionDataTypeMap &spec_data_input) const {
|
||||
MS_EXCEPTION_IF_NULL(op_list_json);
|
||||
MS_LOG(DEBUG) << "Start.";
|
||||
std::vector<std::string> compute_nodes_fullname;
|
||||
|
@ -280,7 +280,7 @@ bool FusionBuildTbeJsonCreator::GenOutputsJson(const AnfNodePtr &anf_node, nlohm
|
|||
}
|
||||
|
||||
void FusionBuildTbeJsonCreator::GenReusedOutputDesc(const AnfNodePtr &anf_node, size_t index, size_t output_index,
|
||||
nlohmann::json *output_desc, size_t out_size) {
|
||||
nlohmann::json *output_desc, size_t out_size) const {
|
||||
GenDesJsonCommon(output_desc);
|
||||
std::string output_desc_name = anf_node->fullname_with_scope() + "_" + std::to_string(index);
|
||||
(*output_desc)[kJName] = output_desc_name;
|
||||
|
|
|
@ -33,9 +33,9 @@ class FusionBuildTbeJsonCreator : public TbeJsonCreator {
|
|||
bool GenOutputsJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json) override;
|
||||
std::vector<size_t> GetDescOutputIndex(const std::vector<int64_t> &output_used_nums) const;
|
||||
void GenReusedOutputDesc(const AnfNodePtr &anf_node, size_t index, size_t output_index, nlohmann::json *output_desc,
|
||||
size_t out_size);
|
||||
size_t out_size) const;
|
||||
void GenDataJson(const std::vector<AnfNodePtr> &compute_nodes, const std::vector<nlohmann::json> &compute_json,
|
||||
std::vector<nlohmann::json> *op_list_json, const ANodeFusionDataTypeMap &spec_data_input);
|
||||
std::vector<nlohmann::json> *op_list_json, const ANodeFusionDataTypeMap &spec_data_input) const;
|
||||
bool AttrsJsonPostProcessing(const AnfNodePtr &, const OpInfoPtr &, nlohmann::json *) override;
|
||||
void GenOtherJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json) override;
|
||||
|
||||
|
|
Loading…
Reference in New Issue