forked from mindspore-Ecosystem/mindspore
modify custom
This commit is contained in:
parent
8de54f3355
commit
4e07e5bf18
|
@ -214,8 +214,10 @@ enum PrimType {
|
|||
PrimType_ResizeGrad = 187,
|
||||
PrimType_Splice = 188,
|
||||
PrimType_LogSoftmax = 189,
|
||||
PrimType_Call = 190,
|
||||
PrimType_Custom = 191,
|
||||
PrimType_MIN = PrimType_NONE,
|
||||
PrimType_MAX = PrimType_LogSoftmax + 1
|
||||
PrimType_MAX = PrimType_Custom + 1
|
||||
};
|
||||
|
||||
void RegInfer(int prim_type, InferShape func);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -15,19 +15,41 @@
|
|||
*/
|
||||
|
||||
#include "ops/custom.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include <memory>
|
||||
#include <map>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
void Custom::Init(const std::vector<int64_t> &custom) { this->set_custom(custom); }
|
||||
void Custom::Init(const std::string &type, const std::map<std::string, std::vector<uint8_t>> &attrs) {
|
||||
this->set_type(type);
|
||||
this->set_attr(attrs);
|
||||
}
|
||||
|
||||
void Custom::set_custom(const std::vector<int64_t> &custom) { this->AddAttr(kCustom, MakeValue(custom)); }
|
||||
void Custom::set_type(const std::string &type) { this->AddAttr(kType, MakeValue(type)); }
|
||||
|
||||
std::vector<int64_t> Custom::get_custom() const {
|
||||
auto value_ptr = this->GetAttr(kCustom);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
std::string Custom::get_type() const {
|
||||
auto value_ptr = this->GetAttr(kType);
|
||||
return GetValue<std::string>(value_ptr);
|
||||
}
|
||||
|
||||
void Custom::set_attr(const std::map<std::string, std::vector<uint8_t>> &attrs) {
|
||||
ValuePtrList value_ptr_list;
|
||||
for (const auto &attr : attrs) {
|
||||
value_ptr_list.emplace_back(MakeValue<std::string>(attr.first));
|
||||
value_ptr_list.emplace_back(MakeValue<std::vector<uint8_t>>(attr.second));
|
||||
}
|
||||
this->AddAttr(kAttr, MakeValue(value_ptr_list));
|
||||
}
|
||||
|
||||
std::map<std::string, std::vector<uint8_t>> Custom::get_attr() const {
|
||||
std::map<std::string, std::vector<uint8_t>> attrs;
|
||||
auto value_ptr_list = GetValue<ValuePtrList>(this->GetAttr(kAttr));
|
||||
for (size_t i = 0; i < value_ptr_list.size(); i += 2) {
|
||||
auto key = GetValue<std::string>(value_ptr_list[i]);
|
||||
auto value = GetValue<std::vector<uint8_t>>(value_ptr_list[i + 1]);
|
||||
attrs[key] = value;
|
||||
}
|
||||
return attrs;
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameCustom, Custom);
|
||||
} // namespace ops
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,12 +16,15 @@
|
|||
|
||||
#ifndef MINDSPORE_CORE_OPS_CUSTOM_H_
|
||||
#define MINDSPORE_CORE_OPS_CUSTOM_H_
|
||||
#include <memory>
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "ir/anf.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -29,11 +32,13 @@ constexpr auto kNameCustom = "Custom";
|
|||
class Custom : public PrimitiveC {
|
||||
public:
|
||||
Custom() : PrimitiveC(kNameCustom) {}
|
||||
~Custom() = default;
|
||||
~Custom() override = default;
|
||||
MS_DECLARE_PARENT(Custom, PrimitiveC);
|
||||
void Init(const std::vector<int64_t> &custom);
|
||||
void set_custom(const std::vector<int64_t> &custom);
|
||||
std::vector<int64_t> get_custom() const;
|
||||
void Init(const std::string &type, const std::map<std::string, std::vector<uint8_t>> &attrs);
|
||||
void set_type(const std::string &type);
|
||||
std::string get_type() const;
|
||||
void set_attr(const std::map<std::string, std::vector<uint8_t>> &attrs);
|
||||
std::map<std::string, std::vector<uint8_t>> get_attr() const;
|
||||
};
|
||||
|
||||
using PrimCustomPtr = std::shared_ptr<Custom>;
|
||||
|
|
|
@ -31,6 +31,7 @@ constexpr auto kActivation = "activation";
|
|||
constexpr auto kActivationType = "activation_type";
|
||||
constexpr auto kAddress = "address";
|
||||
constexpr auto kAlignCorners = "align_corners";
|
||||
constexpr auto kAttr = "attr";
|
||||
constexpr auto kAspectRatios = "aspect_ratios";
|
||||
constexpr auto kAxes = "axes";
|
||||
constexpr auto kAxis = "axis";
|
||||
|
|
|
@ -208,6 +208,7 @@ union PrimitiveType {
|
|||
Splice,
|
||||
LogSoftmax,
|
||||
Call,
|
||||
Custom,
|
||||
}
|
||||
|
||||
table Abs {
|
||||
|
@ -1103,3 +1104,8 @@ table LogSoftmax {
|
|||
|
||||
table Call {
|
||||
}
|
||||
|
||||
table Custom {
|
||||
type: string;
|
||||
attr: [Attribute];
|
||||
}
|
||||
|
|
|
@ -142,3 +142,8 @@ table Vec {
|
|||
table Vec2D {
|
||||
data: [Vec];
|
||||
}
|
||||
|
||||
table Attribute {
|
||||
name: string;
|
||||
data: [ubyte];
|
||||
}
|
||||
|
|
|
@ -207,6 +207,7 @@ OP_TYPE(ResizeGrad)
|
|||
OP_TYPE(Splice)
|
||||
OP_TYPE(LogSoftmax)
|
||||
OP_TYPE(Call)
|
||||
OP_TYPE(Custom)
|
||||
OP_TYPE_DEF_END(PrimitiveType)
|
||||
|
||||
OP_SCHEMA_DEF(Abs)
|
||||
|
@ -1102,3 +1103,8 @@ OP_SCHEMA_DEF_END(LogSoftmax)
|
|||
|
||||
OP_SCHEMA_DEF(Call)
|
||||
OP_SCHEMA_DEF_END(Call)
|
||||
|
||||
OP_SCHEMA_DEF_ONLY(Custom)
|
||||
OP_ATTR_ONLY(type, string)
|
||||
OP_ATTR_ONLY(attr, [Attribute])
|
||||
OP_SCHEMA_DEF_ONLY_END(Custom)
|
||||
|
|
|
@ -975,6 +975,39 @@ RegistryMSOps g_erfPrimitiveCreatorRegistry("Erf", ErfPrimitiveCreator);
|
|||
RegistryMSOps g_SplicePrimitiveCreatorRegistry("Splice", SplicePrimitiveCreator);
|
||||
RegistryMSOps g_LogSoftmaxPrimitiveCreatorRegistry("LogSoftmax", LogSoftmaxPrimitiveCreator);
|
||||
RegistryMSOps g_CallPrimitiveCreatorRegistry("call", CallPrimitiveCreator);
|
||||
|
||||
schema::PrimitiveT *CustomPrimitiveCreator(const AnfNodePtr &node) {
|
||||
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Custom>>(node);
|
||||
auto *schema_op = new (std::nothrow) schema::CustomT();
|
||||
if (schema_op == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
if (ms_primc->GetAttr("type") != nullptr) {
|
||||
schema_op->type = ms_primc->get_type();
|
||||
}
|
||||
if (ms_primc->GetAttr("attr") != nullptr) {
|
||||
auto attr_map = ms_primc->get_attr();
|
||||
for (const auto &attr_item : attr_map) {
|
||||
auto *attr = new (std::nothrow) schema::AttributeT();
|
||||
if (attr == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
attr->name = attr_item.first;
|
||||
attr->data = attr_item.second;
|
||||
schema_op->attr.emplace_back(attr);
|
||||
}
|
||||
}
|
||||
|
||||
auto *prim = new (std::nothrow) schema::PrimitiveT();
|
||||
if (prim == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
prim->value.value = schema_op;
|
||||
prim->value.type = schema::PrimitiveType_Custom;
|
||||
return prim;
|
||||
}
|
||||
|
||||
RegistryMSOps g_CustomPrimitiveCreatorRegistry("Custom", CustomPrimitiveCreator);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
Loading…
Reference in New Issue