forked from mindspore-Ecosystem/mindspore
commit
518e955260
|
@ -50,7 +50,13 @@ bool SetIOIputSize(const std::shared_ptr<AnfNode> &anf_node, const size_t &input
|
|||
MS_LOG(EXCEPTION) << "anf_node is not CNode.";
|
||||
}
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->inputs().size() < (i + 1)) {
|
||||
MS_LOG(ERROR) << "cnode inputs size " << cnode->inputs().size() << " is smaller than " << i + 1;
|
||||
return false;
|
||||
}
|
||||
auto input_node = cnode->inputs()[i + 1];
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
if (input_node->isa<ValueNode>()) {
|
||||
auto value_ptr = GetValueNode(input_node);
|
||||
auto value = GetValue<std::string>(value_ptr);
|
||||
|
@ -103,13 +109,13 @@ bool SetIOSize(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<A
|
|||
output_size_list.push_back(IntToSize(size_i));
|
||||
}
|
||||
kernel_mod_ptr->SetOutputSizeList(output_size_list);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void ParseAttrValue(const std::string &type, const std::string &attr_name, const mindspore::ValuePtr &value,
|
||||
::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr) {
|
||||
MS_EXCEPTION_IF_NULL(node_attr);
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
if (type == "int") {
|
||||
auto attr_value = GetValue<int>(value);
|
||||
(*node_attr)[attr_name].set_i(attr_value);
|
||||
|
@ -146,6 +152,8 @@ void ParseAttrValue(const std::string &type, const std::string &attr_name, const
|
|||
}
|
||||
|
||||
void SetNodeAttr(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
MS_EXCEPTION_IF_NULL(proto);
|
||||
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
|
||||
if (op_name == kInitDataSetQueue) {
|
||||
op_name = kInitData;
|
||||
|
@ -161,15 +169,16 @@ void SetNodeAttr(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *p
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr = proto->mutable_attrs();
|
||||
for (const auto &attr_ptr : attrs_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(attr_ptr);
|
||||
std::string attr_name = attr_ptr->name();
|
||||
auto value = primitive->GetAttr(attr_name);
|
||||
if (value != nullptr) {
|
||||
if (attr_name == kQueueName || attr_name == kSharedName) {
|
||||
attr_name = kChannelName;
|
||||
} else if (attr_name == kSeed) {
|
||||
attr_name = "seed";
|
||||
} else if (attr_name == kSeed2) {
|
||||
attr_name = "seed2";
|
||||
} else if (attr_name == kSeed0) {
|
||||
attr_name = kSeed;
|
||||
} else if (attr_name == kSeed1) {
|
||||
attr_name = kSeed2;
|
||||
}
|
||||
std::string type = attr_ptr->type();
|
||||
ParseAttrValue(type, attr_name, value, node_attr);
|
||||
|
@ -179,6 +188,8 @@ void SetNodeAttr(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *p
|
|||
}
|
||||
|
||||
void SetNodeInputs(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) {
|
||||
MS_EXCEPTION_IF_NULL(proto);
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(anf_node);
|
||||
if (input_num == 0) {
|
||||
MS_LOG(INFO) << "Node [" << AnfAlgo::GetCNodeName(anf_node) << "] does not have input.";
|
||||
|
@ -193,6 +204,7 @@ void SetNodeInputs(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef
|
|||
int32_t input_data_type;
|
||||
if (input_type == kObjectTypeString) {
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto input_node = cnode->inputs()[input_index + 1];
|
||||
auto value_ptr = GetValueNode(input_node);
|
||||
auto value = GetValue<std::string>(value_ptr);
|
||||
|
@ -203,19 +215,20 @@ void SetNodeInputs(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef
|
|||
input_shape = AnfAlgo::GetInputDeviceShape(anf_node, input_index);
|
||||
input_data_type = AicpuOpUtil::MsTypeToProtoType(input_type);
|
||||
}
|
||||
|
||||
mindspore::TensorShape *tensorShape = node_inputs->mutable_tensor_shape();
|
||||
for (auto item : input_shape) {
|
||||
mindspore::TensorShape_Dim *dim = tensorShape->add_dim();
|
||||
dim->set_size((::google::protobuf::int64)item);
|
||||
}
|
||||
|
||||
node_inputs->set_tensor_type((mindspore::DataType)input_data_type);
|
||||
|
||||
node_inputs->set_mem_device("HBM");
|
||||
}
|
||||
}
|
||||
|
||||
void SetNodeOutputs(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) {
|
||||
MS_EXCEPTION_IF_NULL(proto);
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node);
|
||||
if (output_num == 0) {
|
||||
MS_LOG(INFO) << "Node [" << AnfAlgo::GetCNodeName(anf_node) << "] does not have output. ";
|
||||
|
@ -224,63 +237,55 @@ void SetNodeOutputs(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef
|
|||
|
||||
for (size_t output_index = 0; output_index < output_num; output_index++) {
|
||||
::mindspore::Tensor *node_outputs = proto->add_outputs();
|
||||
MS_EXCEPTION_IF_NULL(node_outputs);
|
||||
std::vector<size_t> output_shape = AnfAlgo::GetOutputDeviceShape(anf_node, output_index);
|
||||
mindspore::TensorShape *tensorShape = node_outputs->mutable_tensor_shape();
|
||||
MS_EXCEPTION_IF_NULL(tensorShape);
|
||||
for (auto item : output_shape) {
|
||||
mindspore::TensorShape_Dim *dim = tensorShape->add_dim();
|
||||
MS_EXCEPTION_IF_NULL(dim);
|
||||
dim->set_size((::google::protobuf::int64)item);
|
||||
}
|
||||
|
||||
TypeId output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, output_index);
|
||||
|
||||
int32_t output_data_type = AicpuOpUtil::MsTypeToProtoType(output_type);
|
||||
node_outputs->set_tensor_type((mindspore::DataType)output_data_type);
|
||||
|
||||
node_outputs->set_mem_device("HBM");
|
||||
}
|
||||
}
|
||||
|
||||
void SetNodedefProto(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) {
|
||||
MS_LOG(INFO) << "SetNodedefProto entry";
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
MS_EXCEPTION_IF_NULL(proto);
|
||||
|
||||
MS_LOG(INFO) << "SetNodedefProto entry";
|
||||
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
|
||||
if (op_name == "InitDataSetQueue") {
|
||||
op_name = "InitData";
|
||||
if (op_name == kInitDataSetQueue) {
|
||||
op_name = kInitData;
|
||||
}
|
||||
// set op name
|
||||
proto->set_op(op_name);
|
||||
|
||||
// set inputs tensor
|
||||
SetNodeInputs(anf_node, proto);
|
||||
|
||||
// set outputs tensor
|
||||
SetNodeOutputs(anf_node, proto);
|
||||
|
||||
// set node attr
|
||||
SetNodeAttr(anf_node, proto);
|
||||
|
||||
MS_LOG(INFO) << "SetNodedefProto end!";
|
||||
}
|
||||
|
||||
bool CreateNodeDefBytes(const std::shared_ptr<AnfNode> &anf_node,
|
||||
const std::shared_ptr<AicpuOpKernelMod> &kernel_mod_ptr) {
|
||||
MS_LOG(INFO) << "CreateNodeDefBytes entry";
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
MS_LOG(INFO) << "CreateNodeDefBytes entry";
|
||||
|
||||
mindspore::NodeDef proto;
|
||||
|
||||
SetNodedefProto(anf_node, &proto);
|
||||
|
||||
std::string nodeDefStr;
|
||||
if (!proto.SerializeToString(&nodeDefStr)) {
|
||||
MS_LOG(ERROR) << "Serialize nodeDef to string failed.";
|
||||
return false;
|
||||
}
|
||||
|
||||
kernel_mod_ptr->SetNodeDef(nodeDefStr);
|
||||
|
||||
MS_LOG(INFO) << "CreateNodeDefBytes end!";
|
||||
return true;
|
||||
}
|
||||
|
@ -288,8 +293,8 @@ bool CreateNodeDefBytes(const std::shared_ptr<AnfNode> &anf_node,
|
|||
KernelModPtr AicpuOpBuild(const std::shared_ptr<AnfNode> &anf_node) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
|
||||
if (op_name == "InitDataSetQueue") {
|
||||
op_name = "InitData";
|
||||
if (op_name == kInitDataSetQueue) {
|
||||
op_name = kInitData;
|
||||
}
|
||||
auto kernel_mod_ptr = std::make_shared<AicpuOpKernelMod>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
|
||||
|
|
|
@ -110,8 +110,8 @@ bool AicpuOpKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::
|
|||
}
|
||||
|
||||
CreateCpuKernelInfo(inputs, outputs);
|
||||
if (node_name_ == "TopK") {
|
||||
node_name_ = "TopKV2";
|
||||
if (node_name_ == kTopK) {
|
||||
node_name_ = kTopKV2;
|
||||
}
|
||||
MS_LOG(INFO) << "Aicpu launch, node_so_:" << node_so_ << ", node name:" << node_name_
|
||||
<< ", args_size:" << args_.length();
|
||||
|
@ -141,8 +141,8 @@ std::vector<TaskInfoPtr> AicpuOpKernelMod::GenTask(const std::vector<AddressPtr>
|
|||
(void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_data_addrs),
|
||||
[](const AddressPtr &output) -> void * { return output->addr; });
|
||||
|
||||
if (node_name_ == "TopK") {
|
||||
node_name_ = "TopKV2";
|
||||
if (node_name_ == kTopK) {
|
||||
node_name_ = kTopKV2;
|
||||
}
|
||||
AicpuTaskInfoPtr task_info_ptr = make_shared<ge::model_runner::AicpuTaskInfo>(
|
||||
stream_id, node_so_, node_name_, node_def_str_, input_data_addrs, output_data_addrs);
|
||||
|
|
|
@ -28,7 +28,6 @@ constexpr auto kInitDataSetQueue = "InitDataSetQueue";
|
|||
constexpr auto kInitData = "InitData";
|
||||
constexpr auto kGetNext = "GetNext";
|
||||
constexpr auto kPrint = "Print";
|
||||
|
||||
constexpr auto kOutputTypes = "output_types";
|
||||
constexpr auto kOutputShapes = "output_shapes";
|
||||
constexpr auto kChannelName = "channel_name";
|
||||
|
@ -36,9 +35,12 @@ constexpr auto kSharedName = "shared_name";
|
|||
constexpr auto kShapes = "shapes";
|
||||
constexpr auto kTypes = "types";
|
||||
constexpr auto kQueueName = "queue_name";
|
||||
|
||||
constexpr auto kSeed = "Seed0";
|
||||
constexpr auto kSeed2 = "Seed1";
|
||||
constexpr auto kSeed = "seed";
|
||||
constexpr auto kSeed0 = "Seed0";
|
||||
constexpr auto kSeed1 = "Seed1";
|
||||
constexpr auto kSeed2 = "seed2";
|
||||
constexpr auto kTopK = "TopK";
|
||||
constexpr auto kTopKV2 = "TopKV2";
|
||||
|
||||
struct AicpuParamHead {
|
||||
uint32_t length; // Total length: include cunstom message
|
||||
|
|
|
@ -95,12 +95,7 @@ class OpInfo {
|
|||
OpImplyType imply_type() const { return imply_type_; }
|
||||
std::string impl_path() const { return impl_path_; }
|
||||
std::string fusion_type() const { return fusion_type_; }
|
||||
bool async_flag() const { return async_flag_; }
|
||||
std::string binfile_name() const { return binfile_name_; }
|
||||
int compute_cost() const { return compute_cost_; }
|
||||
std::string kernel_name() const { return kernel_name_; }
|
||||
bool partial_flag() const { return partial_flag_; }
|
||||
bool dynamic_format() const { return dynamic_format_; }
|
||||
OpPattern op_pattern() const { return op_pattern_; }
|
||||
std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; }
|
||||
std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; }
|
||||
|
@ -116,13 +111,10 @@ class OpInfo {
|
|||
void set_compute_cost(const int compute_cost) { compute_cost_ = compute_cost; }
|
||||
void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; }
|
||||
void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; }
|
||||
void set_dynamic_format(const bool dynamic_format) { dynamic_format_ = dynamic_format; }
|
||||
void set_op_pattern(const OpPattern op_pattern) { op_pattern_ = op_pattern; }
|
||||
void add_attrs_ptr(const std::shared_ptr<OpAttr> &attr) { attrs_ptr_.push_back(attr); }
|
||||
void add_inputs_ptr(const std::shared_ptr<OpIOInfo> &input) { inputs_ptr_.push_back(input); }
|
||||
void add_outputs_ptr(const std::shared_ptr<OpIOInfo> &output) { outputs_ptr_.push_back(output); }
|
||||
void set_inputs_ptr(const std::vector<std::shared_ptr<OpIOInfo>> &inputs) { inputs_ptr_ = inputs; }
|
||||
void set_outputs_ptr(const std::vector<std::shared_ptr<OpIOInfo>> &outputs) { outputs_ptr_ = outputs; }
|
||||
bool is_ref() const { return !ref_infos_.empty(); }
|
||||
bool has_ref_index(size_t out_index) const { return ref_infos_.find(out_index) != ref_infos_.end(); }
|
||||
void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); }
|
||||
|
|
|
@ -103,6 +103,7 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p
|
|||
{kBroadcast, kBroadcastPattern},
|
||||
{kReduce, kReducePattern},
|
||||
{kDynamicFormat, kDynamicFormatPattern}};
|
||||
MS_EXCEPTION_IF_NULL(op_info);
|
||||
op_info->set_async_flag(obj.at(kAsyncFlag));
|
||||
op_info->set_binfile_name(obj.at(kBinfileName));
|
||||
op_info->set_compute_cost(obj.at(kComputeCost));
|
||||
|
@ -199,6 +200,7 @@ bool OpLib::DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type,
|
|||
|
||||
bool OpLib::DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr<OpIOInfo> &op_io,
|
||||
size_t index) {
|
||||
MS_EXCEPTION_IF_NULL(op_io);
|
||||
bool ret = true;
|
||||
try {
|
||||
std::vector<std::string> dtype;
|
||||
|
@ -218,6 +220,7 @@ bool OpLib::DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::sha
|
|||
|
||||
bool OpLib::DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type,
|
||||
const std::shared_ptr<OpInfo> &op_info, const nlohmann::json &dtype_format) {
|
||||
MS_EXCEPTION_IF_NULL(op_info);
|
||||
bool ret = true;
|
||||
try {
|
||||
std::shared_ptr<OpIOInfo> op_io = std::make_shared<OpIOInfo>();
|
||||
|
|
Loading…
Reference in New Issue