modify custom

This commit is contained in:
liuyu 2021-04-25 15:13:17 +08:00
parent 8de54f3355
commit 4e07e5bf18
8 changed files with 99 additions and 19 deletions

View File

@ -214,8 +214,10 @@ enum PrimType {
PrimType_ResizeGrad = 187,
PrimType_Splice = 188,
PrimType_LogSoftmax = 189,
PrimType_Call = 190,
PrimType_Custom = 191,
PrimType_MIN = PrimType_NONE,
PrimType_MAX = PrimType_LogSoftmax + 1
PrimType_MAX = PrimType_Custom + 1
};
void RegInfer(int prim_type, InferShape func);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -15,19 +15,41 @@
*/
#include "ops/custom.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "ops/op_utils.h"
#include <memory>
#include <map>
namespace mindspore {
namespace ops {
void Custom::Init(const std::vector<int64_t> &custom) { this->set_custom(custom); }
void Custom::Init(const std::string &type, const std::map<std::string, std::vector<uint8_t>> &attrs) {
this->set_type(type);
this->set_attr(attrs);
}
void Custom::set_custom(const std::vector<int64_t> &custom) { this->AddAttr(kCustom, MakeValue(custom)); }
void Custom::set_type(const std::string &type) { this->AddAttr(kType, MakeValue(type)); }
std::vector<int64_t> Custom::get_custom() const {
auto value_ptr = this->GetAttr(kCustom);
return GetValue<std::vector<int64_t>>(value_ptr);
std::string Custom::get_type() const {
auto value_ptr = this->GetAttr(kType);
return GetValue<std::string>(value_ptr);
}
void Custom::set_attr(const std::map<std::string, std::vector<uint8_t>> &attrs) {
ValuePtrList value_ptr_list;
for (const auto &attr : attrs) {
value_ptr_list.emplace_back(MakeValue<std::string>(attr.first));
value_ptr_list.emplace_back(MakeValue<std::vector<uint8_t>>(attr.second));
}
this->AddAttr(kAttr, MakeValue(value_ptr_list));
}
std::map<std::string, std::vector<uint8_t>> Custom::get_attr() const {
std::map<std::string, std::vector<uint8_t>> attrs;
auto value_ptr_list = GetValue<ValuePtrList>(this->GetAttr(kAttr));
for (size_t i = 0; i < value_ptr_list.size(); i += 2) {
auto key = GetValue<std::string>(value_ptr_list[i]);
auto value = GetValue<std::vector<uint8_t>>(value_ptr_list[i + 1]);
attrs[key] = value;
}
return attrs;
}
REGISTER_PRIMITIVE_C(kNameCustom, Custom);
} // namespace ops

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,12 +16,15 @@
#ifndef MINDSPORE_CORE_OPS_CUSTOM_H_
#define MINDSPORE_CORE_OPS_CUSTOM_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <map>
#include <memory>
#include <unordered_map>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"
#include "ir/anf.h"
namespace mindspore {
namespace ops {
@ -29,11 +32,13 @@ constexpr auto kNameCustom = "Custom";
class Custom : public PrimitiveC {
public:
Custom() : PrimitiveC(kNameCustom) {}
~Custom() = default;
~Custom() override = default;
MS_DECLARE_PARENT(Custom, PrimitiveC);
void Init(const std::vector<int64_t> &custom);
void set_custom(const std::vector<int64_t> &custom);
std::vector<int64_t> get_custom() const;
void Init(const std::string &type, const std::map<std::string, std::vector<uint8_t>> &attrs);
void set_type(const std::string &type);
std::string get_type() const;
void set_attr(const std::map<std::string, std::vector<uint8_t>> &attrs);
std::map<std::string, std::vector<uint8_t>> get_attr() const;
};
using PrimCustomPtr = std::shared_ptr<Custom>;

View File

@ -31,6 +31,7 @@ constexpr auto kActivation = "activation";
constexpr auto kActivationType = "activation_type";
constexpr auto kAddress = "address";
constexpr auto kAlignCorners = "align_corners";
constexpr auto kAttr = "attr";
constexpr auto kAspectRatios = "aspect_ratios";
constexpr auto kAxes = "axes";
constexpr auto kAxis = "axis";

View File

@ -208,6 +208,7 @@ union PrimitiveType {
Splice,
LogSoftmax,
Call,
Custom,
}
table Abs {
@ -1103,3 +1104,8 @@ table LogSoftmax {
table Call {
}
table Custom {
type: string;
attr: [Attribute];
}

View File

@ -142,3 +142,8 @@ table Vec {
table Vec2D {
data: [Vec];
}
table Attribute {
name: string;
data: [ubyte];
}

View File

@ -207,6 +207,7 @@ OP_TYPE(ResizeGrad)
OP_TYPE(Splice)
OP_TYPE(LogSoftmax)
OP_TYPE(Call)
OP_TYPE(Custom)
OP_TYPE_DEF_END(PrimitiveType)
OP_SCHEMA_DEF(Abs)
@ -1102,3 +1103,8 @@ OP_SCHEMA_DEF_END(LogSoftmax)
OP_SCHEMA_DEF(Call)
OP_SCHEMA_DEF_END(Call)
OP_SCHEMA_DEF_ONLY(Custom)
OP_ATTR_ONLY(type, string)
OP_ATTR_ONLY(attr, [Attribute])
OP_SCHEMA_DEF_ONLY_END(Custom)

View File

@ -975,6 +975,39 @@ RegistryMSOps g_erfPrimitiveCreatorRegistry("Erf", ErfPrimitiveCreator);
RegistryMSOps g_SplicePrimitiveCreatorRegistry("Splice", SplicePrimitiveCreator);
RegistryMSOps g_LogSoftmaxPrimitiveCreatorRegistry("LogSoftmax", LogSoftmaxPrimitiveCreator);
RegistryMSOps g_CallPrimitiveCreatorRegistry("call", CallPrimitiveCreator);
schema::PrimitiveT *CustomPrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Custom>>(node);
auto *schema_op = new (std::nothrow) schema::CustomT();
if (schema_op == nullptr) {
return nullptr;
}
if (ms_primc->GetAttr("type") != nullptr) {
schema_op->type = ms_primc->get_type();
}
if (ms_primc->GetAttr("attr") != nullptr) {
auto attr_map = ms_primc->get_attr();
for (const auto &attr_item : attr_map) {
auto *attr = new (std::nothrow) schema::AttributeT();
if (attr == nullptr) {
return nullptr;
}
attr->name = attr_item.first;
attr->data = attr_item.second;
schema_op->attr.emplace_back(attr);
}
}
auto *prim = new (std::nothrow) schema::PrimitiveT();
if (prim == nullptr) {
return nullptr;
}
prim->value.value = schema_op;
prim->value.type = schema::PrimitiveType_Custom;
return prim;
}
RegistryMSOps g_CustomPrimitiveCreatorRegistry("Custom", CustomPrimitiveCreator);
} // namespace lite
} // namespace mindspore