modify custom

This commit is contained in:
liuyu 2021-04-25 15:13:17 +08:00
parent 8de54f3355
commit 4e07e5bf18
8 changed files with 99 additions and 19 deletions

View File

@ -214,8 +214,10 @@ enum PrimType {
PrimType_ResizeGrad = 187, PrimType_ResizeGrad = 187,
PrimType_Splice = 188, PrimType_Splice = 188,
PrimType_LogSoftmax = 189, PrimType_LogSoftmax = 189,
PrimType_Call = 190,
PrimType_Custom = 191,
PrimType_MIN = PrimType_NONE, PrimType_MIN = PrimType_NONE,
PrimType_MAX = PrimType_LogSoftmax + 1 PrimType_MAX = PrimType_Custom + 1
}; };
void RegInfer(int prim_type, InferShape func); void RegInfer(int prim_type, InferShape func);

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -15,19 +15,41 @@
*/ */
#include "ops/custom.h" #include "ops/custom.h"
#include "utils/check_convert_utils.h" #include <memory>
#include "abstract/primitive_infer_map.h" #include <map>
#include "ops/op_utils.h"
namespace mindspore { namespace mindspore {
namespace ops { namespace ops {
void Custom::Init(const std::vector<int64_t> &custom) { this->set_custom(custom); } void Custom::Init(const std::string &type, const std::map<std::string, std::vector<uint8_t>> &attrs) {
this->set_type(type);
this->set_attr(attrs);
}
void Custom::set_custom(const std::vector<int64_t> &custom) { this->AddAttr(kCustom, MakeValue(custom)); } void Custom::set_type(const std::string &type) { this->AddAttr(kType, MakeValue(type)); }
std::vector<int64_t> Custom::get_custom() const { std::string Custom::get_type() const {
auto value_ptr = this->GetAttr(kCustom); auto value_ptr = this->GetAttr(kType);
return GetValue<std::vector<int64_t>>(value_ptr); return GetValue<std::string>(value_ptr);
}
void Custom::set_attr(const std::map<std::string, std::vector<uint8_t>> &attrs) {
ValuePtrList value_ptr_list;
for (const auto &attr : attrs) {
value_ptr_list.emplace_back(MakeValue<std::string>(attr.first));
value_ptr_list.emplace_back(MakeValue<std::vector<uint8_t>>(attr.second));
}
this->AddAttr(kAttr, MakeValue(value_ptr_list));
}
std::map<std::string, std::vector<uint8_t>> Custom::get_attr() const {
std::map<std::string, std::vector<uint8_t>> attrs;
auto value_ptr_list = GetValue<ValuePtrList>(this->GetAttr(kAttr));
for (size_t i = 0; i < value_ptr_list.size(); i += 2) {
auto key = GetValue<std::string>(value_ptr_list[i]);
auto value = GetValue<std::vector<uint8_t>>(value_ptr_list[i + 1]);
attrs[key] = value;
}
return attrs;
} }
REGISTER_PRIMITIVE_C(kNameCustom, Custom); REGISTER_PRIMITIVE_C(kNameCustom, Custom);
} // namespace ops } // namespace ops

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,12 +16,15 @@
#ifndef MINDSPORE_CORE_OPS_CUSTOM_H_ #ifndef MINDSPORE_CORE_OPS_CUSTOM_H_
#define MINDSPORE_CORE_OPS_CUSTOM_H_ #define MINDSPORE_CORE_OPS_CUSTOM_H_
#include <memory> #include <string>
#include <utility>
#include <vector> #include <vector>
#include <map>
#include <memory>
#include <unordered_map>
#include "ops/primitive_c.h" #include "ops/primitive_c.h"
#include "abstract/abstract_value.h" #include "ops/op_utils.h"
#include "utils/check_convert_utils.h" #include "ir/anf.h"
namespace mindspore { namespace mindspore {
namespace ops { namespace ops {
@ -29,11 +32,13 @@ constexpr auto kNameCustom = "Custom";
class Custom : public PrimitiveC { class Custom : public PrimitiveC {
public: public:
Custom() : PrimitiveC(kNameCustom) {} Custom() : PrimitiveC(kNameCustom) {}
~Custom() = default; ~Custom() override = default;
MS_DECLARE_PARENT(Custom, PrimitiveC); MS_DECLARE_PARENT(Custom, PrimitiveC);
void Init(const std::vector<int64_t> &custom); void Init(const std::string &type, const std::map<std::string, std::vector<uint8_t>> &attrs);
void set_custom(const std::vector<int64_t> &custom); void set_type(const std::string &type);
std::vector<int64_t> get_custom() const; std::string get_type() const;
void set_attr(const std::map<std::string, std::vector<uint8_t>> &attrs);
std::map<std::string, std::vector<uint8_t>> get_attr() const;
}; };
using PrimCustomPtr = std::shared_ptr<Custom>; using PrimCustomPtr = std::shared_ptr<Custom>;

View File

@ -31,6 +31,7 @@ constexpr auto kActivation = "activation";
constexpr auto kActivationType = "activation_type"; constexpr auto kActivationType = "activation_type";
constexpr auto kAddress = "address"; constexpr auto kAddress = "address";
constexpr auto kAlignCorners = "align_corners"; constexpr auto kAlignCorners = "align_corners";
constexpr auto kAttr = "attr";
constexpr auto kAspectRatios = "aspect_ratios"; constexpr auto kAspectRatios = "aspect_ratios";
constexpr auto kAxes = "axes"; constexpr auto kAxes = "axes";
constexpr auto kAxis = "axis"; constexpr auto kAxis = "axis";

View File

@ -208,6 +208,7 @@ union PrimitiveType {
Splice, Splice,
LogSoftmax, LogSoftmax,
Call, Call,
Custom,
} }
table Abs { table Abs {
@ -1103,3 +1104,8 @@ table LogSoftmax {
table Call { table Call {
} }
table Custom {
type: string;
attr: [Attribute];
}

View File

@ -142,3 +142,8 @@ table Vec {
table Vec2D { table Vec2D {
data: [Vec]; data: [Vec];
} }
table Attribute {
name: string;
data: [ubyte];
}

View File

@ -207,6 +207,7 @@ OP_TYPE(ResizeGrad)
OP_TYPE(Splice) OP_TYPE(Splice)
OP_TYPE(LogSoftmax) OP_TYPE(LogSoftmax)
OP_TYPE(Call) OP_TYPE(Call)
OP_TYPE(Custom)
OP_TYPE_DEF_END(PrimitiveType) OP_TYPE_DEF_END(PrimitiveType)
OP_SCHEMA_DEF(Abs) OP_SCHEMA_DEF(Abs)
@ -1102,3 +1103,8 @@ OP_SCHEMA_DEF_END(LogSoftmax)
OP_SCHEMA_DEF(Call) OP_SCHEMA_DEF(Call)
OP_SCHEMA_DEF_END(Call) OP_SCHEMA_DEF_END(Call)
OP_SCHEMA_DEF_ONLY(Custom)
OP_ATTR_ONLY(type, string)
OP_ATTR_ONLY(attr, [Attribute])
OP_SCHEMA_DEF_ONLY_END(Custom)

View File

@ -975,6 +975,39 @@ RegistryMSOps g_erfPrimitiveCreatorRegistry("Erf", ErfPrimitiveCreator);
RegistryMSOps g_SplicePrimitiveCreatorRegistry("Splice", SplicePrimitiveCreator); RegistryMSOps g_SplicePrimitiveCreatorRegistry("Splice", SplicePrimitiveCreator);
RegistryMSOps g_LogSoftmaxPrimitiveCreatorRegistry("LogSoftmax", LogSoftmaxPrimitiveCreator); RegistryMSOps g_LogSoftmaxPrimitiveCreatorRegistry("LogSoftmax", LogSoftmaxPrimitiveCreator);
RegistryMSOps g_CallPrimitiveCreatorRegistry("call", CallPrimitiveCreator); RegistryMSOps g_CallPrimitiveCreatorRegistry("call", CallPrimitiveCreator);
schema::PrimitiveT *CustomPrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Custom>>(node);
auto *schema_op = new (std::nothrow) schema::CustomT();
if (schema_op == nullptr) {
return nullptr;
}
if (ms_primc->GetAttr("type") != nullptr) {
schema_op->type = ms_primc->get_type();
}
if (ms_primc->GetAttr("attr") != nullptr) {
auto attr_map = ms_primc->get_attr();
for (const auto &attr_item : attr_map) {
auto *attr = new (std::nothrow) schema::AttributeT();
if (attr == nullptr) {
return nullptr;
}
attr->name = attr_item.first;
attr->data = attr_item.second;
schema_op->attr.emplace_back(attr);
}
}
auto *prim = new (std::nothrow) schema::PrimitiveT();
if (prim == nullptr) {
return nullptr;
}
prim->value.value = schema_op;
prim->value.type = schema::PrimitiveType_Custom;
return prim;
}
RegistryMSOps g_CustomPrimitiveCreatorRegistry("Custom", CustomPrimitiveCreator);
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore