fix code review

This commit is contained in:
zjun 2020-06-15 10:05:40 +08:00
parent 4b5cbe5d4a
commit b69c200331
5 changed files with 44 additions and 42 deletions

View File

@ -50,7 +50,13 @@ bool SetIOIputSize(const std::shared_ptr<AnfNode> &anf_node, const size_t &input
MS_LOG(EXCEPTION) << "anf_node is not CNode.";
}
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().size() < (i + 1)) {
MS_LOG(ERROR) << "cnode inputs size " << cnode->inputs().size() << " is smaller than " << i + 1;
return false;
}
auto input_node = cnode->inputs()[i + 1];
MS_EXCEPTION_IF_NULL(input_node);
if (input_node->isa<ValueNode>()) {
auto value_ptr = GetValueNode(input_node);
auto value = GetValue<std::string>(value_ptr);
@ -103,13 +109,13 @@ bool SetIOSize(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<A
output_size_list.push_back(IntToSize(size_i));
}
kernel_mod_ptr->SetOutputSizeList(output_size_list);
return true;
}
void ParseAttrValue(const std::string &type, const std::string &attr_name, const mindspore::ValuePtr &value,
::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr) {
MS_EXCEPTION_IF_NULL(node_attr);
MS_EXCEPTION_IF_NULL(value);
if (type == "int") {
auto attr_value = GetValue<int>(value);
(*node_attr)[attr_name].set_i(attr_value);
@ -146,6 +152,8 @@ void ParseAttrValue(const std::string &type, const std::string &attr_name, const
}
void SetNodeAttr(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(proto);
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
if (op_name == kInitDataSetQueue) {
op_name = kInitData;
@ -161,15 +169,16 @@ void SetNodeAttr(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *p
MS_EXCEPTION_IF_NULL(primitive);
::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr = proto->mutable_attrs();
for (const auto &attr_ptr : attrs_ptr) {
MS_EXCEPTION_IF_NULL(attr_ptr);
std::string attr_name = attr_ptr->name();
auto value = primitive->GetAttr(attr_name);
if (value != nullptr) {
if (attr_name == kQueueName || attr_name == kSharedName) {
attr_name = kChannelName;
} else if (attr_name == kSeed) {
attr_name = "seed";
} else if (attr_name == kSeed2) {
attr_name = "seed2";
} else if (attr_name == kSeed0) {
attr_name = kSeed;
} else if (attr_name == kSeed1) {
attr_name = kSeed2;
}
std::string type = attr_ptr->type();
ParseAttrValue(type, attr_name, value, node_attr);
@ -179,6 +188,8 @@ void SetNodeAttr(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *p
}
void SetNodeInputs(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) {
MS_EXCEPTION_IF_NULL(proto);
MS_EXCEPTION_IF_NULL(anf_node);
size_t input_num = AnfAlgo::GetInputTensorNum(anf_node);
if (input_num == 0) {
MS_LOG(INFO) << "Node [" << AnfAlgo::GetCNodeName(anf_node) << "] does not have input.";
@ -193,6 +204,7 @@ void SetNodeInputs(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef
int32_t input_data_type;
if (input_type == kObjectTypeString) {
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto input_node = cnode->inputs()[input_index + 1];
auto value_ptr = GetValueNode(input_node);
auto value = GetValue<std::string>(value_ptr);
@ -203,19 +215,20 @@ void SetNodeInputs(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef
input_shape = AnfAlgo::GetInputDeviceShape(anf_node, input_index);
input_data_type = AicpuOpUtil::MsTypeToProtoType(input_type);
}
mindspore::TensorShape *tensorShape = node_inputs->mutable_tensor_shape();
for (auto item : input_shape) {
mindspore::TensorShape_Dim *dim = tensorShape->add_dim();
dim->set_size((::google::protobuf::int64)item);
}
node_inputs->set_tensor_type((mindspore::DataType)input_data_type);
node_inputs->set_mem_device("HBM");
}
}
void SetNodeOutputs(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) {
MS_EXCEPTION_IF_NULL(proto);
MS_EXCEPTION_IF_NULL(anf_node);
size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node);
if (output_num == 0) {
MS_LOG(INFO) << "Node [" << AnfAlgo::GetCNodeName(anf_node) << "] does not have output. ";
@ -224,63 +237,55 @@ void SetNodeOutputs(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef
for (size_t output_index = 0; output_index < output_num; output_index++) {
::mindspore::Tensor *node_outputs = proto->add_outputs();
MS_EXCEPTION_IF_NULL(node_outputs);
std::vector<size_t> output_shape = AnfAlgo::GetOutputDeviceShape(anf_node, output_index);
mindspore::TensorShape *tensorShape = node_outputs->mutable_tensor_shape();
MS_EXCEPTION_IF_NULL(tensorShape);
for (auto item : output_shape) {
mindspore::TensorShape_Dim *dim = tensorShape->add_dim();
MS_EXCEPTION_IF_NULL(dim);
dim->set_size((::google::protobuf::int64)item);
}
TypeId output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, output_index);
int32_t output_data_type = AicpuOpUtil::MsTypeToProtoType(output_type);
node_outputs->set_tensor_type((mindspore::DataType)output_data_type);
node_outputs->set_mem_device("HBM");
}
}
void SetNodedefProto(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) {
MS_LOG(INFO) << "SetNodedefProto entry";
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(proto);
MS_LOG(INFO) << "SetNodedefProto entry";
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
if (op_name == "InitDataSetQueue") {
op_name = "InitData";
if (op_name == kInitDataSetQueue) {
op_name = kInitData;
}
// set op name
proto->set_op(op_name);
// set inputs tensor
SetNodeInputs(anf_node, proto);
// set outputs tensor
SetNodeOutputs(anf_node, proto);
// set node attr
SetNodeAttr(anf_node, proto);
MS_LOG(INFO) << "SetNodedefProto end!";
}
bool CreateNodeDefBytes(const std::shared_ptr<AnfNode> &anf_node,
const std::shared_ptr<AicpuOpKernelMod> &kernel_mod_ptr) {
MS_LOG(INFO) << "CreateNodeDefBytes entry";
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
MS_EXCEPTION_IF_NULL(anf_node);
MS_LOG(INFO) << "CreateNodeDefBytes entry";
mindspore::NodeDef proto;
SetNodedefProto(anf_node, &proto);
std::string nodeDefStr;
if (!proto.SerializeToString(&nodeDefStr)) {
MS_LOG(ERROR) << "Serialize nodeDef to string failed.";
return false;
}
kernel_mod_ptr->SetNodeDef(nodeDefStr);
MS_LOG(INFO) << "CreateNodeDefBytes end!";
return true;
}
@ -288,8 +293,8 @@ bool CreateNodeDefBytes(const std::shared_ptr<AnfNode> &anf_node,
KernelModPtr AicpuOpBuild(const std::shared_ptr<AnfNode> &anf_node) {
MS_EXCEPTION_IF_NULL(anf_node);
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
if (op_name == "InitDataSetQueue") {
op_name = "InitData";
if (op_name == kInitDataSetQueue) {
op_name = kInitData;
}
auto kernel_mod_ptr = std::make_shared<AicpuOpKernelMod>();
MS_EXCEPTION_IF_NULL(kernel_mod_ptr);

View File

@ -110,8 +110,8 @@ bool AicpuOpKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::
}
CreateCpuKernelInfo(inputs, outputs);
if (node_name_ == "TopK") {
node_name_ = "TopKV2";
if (node_name_ == kTopK) {
node_name_ = kTopKV2;
}
MS_LOG(INFO) << "Aicpu launch, node_so_:" << node_so_ << ", node name:" << node_name_
<< ", args_size:" << args_.length();
@ -141,8 +141,8 @@ std::vector<TaskInfoPtr> AicpuOpKernelMod::GenTask(const std::vector<AddressPtr>
(void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_data_addrs),
[](const AddressPtr &output) -> void * { return output->addr; });
if (node_name_ == "TopK") {
node_name_ = "TopKV2";
if (node_name_ == kTopK) {
node_name_ = kTopKV2;
}
AicpuTaskInfoPtr task_info_ptr = make_shared<ge::model_runner::AicpuTaskInfo>(
stream_id, node_so_, node_name_, node_def_str_, input_data_addrs, output_data_addrs);

View File

@ -28,7 +28,6 @@ constexpr auto kInitDataSetQueue = "InitDataSetQueue";
constexpr auto kInitData = "InitData";
constexpr auto kGetNext = "GetNext";
constexpr auto kPrint = "Print";
constexpr auto kOutputTypes = "output_types";
constexpr auto kOutputShapes = "output_shapes";
constexpr auto kChannelName = "channel_name";
@ -36,9 +35,12 @@ constexpr auto kSharedName = "shared_name";
constexpr auto kShapes = "shapes";
constexpr auto kTypes = "types";
constexpr auto kQueueName = "queue_name";
constexpr auto kSeed = "Seed0";
constexpr auto kSeed2 = "Seed1";
constexpr auto kSeed = "seed";
constexpr auto kSeed0 = "Seed0";
constexpr auto kSeed1 = "Seed1";
constexpr auto kSeed2 = "seed2";
constexpr auto kTopK = "TopK";
constexpr auto kTopKV2 = "TopKV2";
struct AicpuParamHead {
uint32_t length; // Total length: include cunstom message

View File

@ -95,12 +95,7 @@ class OpInfo {
OpImplyType imply_type() const { return imply_type_; }
std::string impl_path() const { return impl_path_; }
std::string fusion_type() const { return fusion_type_; }
bool async_flag() const { return async_flag_; }
std::string binfile_name() const { return binfile_name_; }
int compute_cost() const { return compute_cost_; }
std::string kernel_name() const { return kernel_name_; }
bool partial_flag() const { return partial_flag_; }
bool dynamic_format() const { return dynamic_format_; }
OpPattern op_pattern() const { return op_pattern_; }
std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; }
std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; }
@ -116,13 +111,10 @@ class OpInfo {
void set_compute_cost(const int compute_cost) { compute_cost_ = compute_cost; }
void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; }
void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; }
void set_dynamic_format(const bool dynamic_format) { dynamic_format_ = dynamic_format; }
void set_op_pattern(const OpPattern op_pattern) { op_pattern_ = op_pattern; }
void add_attrs_ptr(const std::shared_ptr<OpAttr> &attr) { attrs_ptr_.push_back(attr); }
void add_inputs_ptr(const std::shared_ptr<OpIOInfo> &input) { inputs_ptr_.push_back(input); }
void add_outputs_ptr(const std::shared_ptr<OpIOInfo> &output) { outputs_ptr_.push_back(output); }
void set_inputs_ptr(const std::vector<std::shared_ptr<OpIOInfo>> &inputs) { inputs_ptr_ = inputs; }
void set_outputs_ptr(const std::vector<std::shared_ptr<OpIOInfo>> &outputs) { outputs_ptr_ = outputs; }
bool is_ref() const { return !ref_infos_.empty(); }
bool has_ref_index(size_t out_index) const { return ref_infos_.find(out_index) != ref_infos_.end(); }
void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); }

View File

@ -103,6 +103,7 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p
{kBroadcast, kBroadcastPattern},
{kReduce, kReducePattern},
{kDynamicFormat, kDynamicFormatPattern}};
MS_EXCEPTION_IF_NULL(op_info);
op_info->set_async_flag(obj.at(kAsyncFlag));
op_info->set_binfile_name(obj.at(kBinfileName));
op_info->set_compute_cost(obj.at(kComputeCost));
@ -199,6 +200,7 @@ bool OpLib::DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type,
bool OpLib::DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr<OpIOInfo> &op_io,
size_t index) {
MS_EXCEPTION_IF_NULL(op_io);
bool ret = true;
try {
std::vector<std::string> dtype;
@ -218,6 +220,7 @@ bool OpLib::DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::sha
bool OpLib::DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type,
const std::shared_ptr<OpInfo> &op_info, const nlohmann::json &dtype_format) {
MS_EXCEPTION_IF_NULL(op_info);
bool ret = true;
try {
std::shared_ptr<OpIOInfo> op_io = std::make_shared<OpIOInfo>();