forked from mindspore-Ecosystem/mindspore
modify custom
This commit is contained in:
parent
8de54f3355
commit
4e07e5bf18
|
@ -214,8 +214,10 @@ enum PrimType {
|
||||||
PrimType_ResizeGrad = 187,
|
PrimType_ResizeGrad = 187,
|
||||||
PrimType_Splice = 188,
|
PrimType_Splice = 188,
|
||||||
PrimType_LogSoftmax = 189,
|
PrimType_LogSoftmax = 189,
|
||||||
|
PrimType_Call = 190,
|
||||||
|
PrimType_Custom = 191,
|
||||||
PrimType_MIN = PrimType_NONE,
|
PrimType_MIN = PrimType_NONE,
|
||||||
PrimType_MAX = PrimType_LogSoftmax + 1
|
PrimType_MAX = PrimType_Custom + 1
|
||||||
};
|
};
|
||||||
|
|
||||||
void RegInfer(int prim_type, InferShape func);
|
void RegInfer(int prim_type, InferShape func);
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -15,19 +15,41 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "ops/custom.h"
|
#include "ops/custom.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include <memory>
|
||||||
#include "abstract/primitive_infer_map.h"
|
#include <map>
|
||||||
#include "ops/op_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
void Custom::Init(const std::vector<int64_t> &custom) { this->set_custom(custom); }
|
void Custom::Init(const std::string &type, const std::map<std::string, std::vector<uint8_t>> &attrs) {
|
||||||
|
this->set_type(type);
|
||||||
|
this->set_attr(attrs);
|
||||||
|
}
|
||||||
|
|
||||||
void Custom::set_custom(const std::vector<int64_t> &custom) { this->AddAttr(kCustom, MakeValue(custom)); }
|
void Custom::set_type(const std::string &type) { this->AddAttr(kType, MakeValue(type)); }
|
||||||
|
|
||||||
std::vector<int64_t> Custom::get_custom() const {
|
std::string Custom::get_type() const {
|
||||||
auto value_ptr = this->GetAttr(kCustom);
|
auto value_ptr = this->GetAttr(kType);
|
||||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
return GetValue<std::string>(value_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Custom::set_attr(const std::map<std::string, std::vector<uint8_t>> &attrs) {
|
||||||
|
ValuePtrList value_ptr_list;
|
||||||
|
for (const auto &attr : attrs) {
|
||||||
|
value_ptr_list.emplace_back(MakeValue<std::string>(attr.first));
|
||||||
|
value_ptr_list.emplace_back(MakeValue<std::vector<uint8_t>>(attr.second));
|
||||||
|
}
|
||||||
|
this->AddAttr(kAttr, MakeValue(value_ptr_list));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::map<std::string, std::vector<uint8_t>> Custom::get_attr() const {
|
||||||
|
std::map<std::string, std::vector<uint8_t>> attrs;
|
||||||
|
auto value_ptr_list = GetValue<ValuePtrList>(this->GetAttr(kAttr));
|
||||||
|
for (size_t i = 0; i < value_ptr_list.size(); i += 2) {
|
||||||
|
auto key = GetValue<std::string>(value_ptr_list[i]);
|
||||||
|
auto value = GetValue<std::vector<uint8_t>>(value_ptr_list[i + 1]);
|
||||||
|
attrs[key] = value;
|
||||||
|
}
|
||||||
|
return attrs;
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_C(kNameCustom, Custom);
|
REGISTER_PRIMITIVE_C(kNameCustom, Custom);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -16,12 +16,15 @@
|
||||||
|
|
||||||
#ifndef MINDSPORE_CORE_OPS_CUSTOM_H_
|
#ifndef MINDSPORE_CORE_OPS_CUSTOM_H_
|
||||||
#define MINDSPORE_CORE_OPS_CUSTOM_H_
|
#define MINDSPORE_CORE_OPS_CUSTOM_H_
|
||||||
#include <memory>
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
|
#include <memory>
|
||||||
|
#include <unordered_map>
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/primitive_c.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "ops/op_utils.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "ir/anf.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -29,11 +32,13 @@ constexpr auto kNameCustom = "Custom";
|
||||||
class Custom : public PrimitiveC {
|
class Custom : public PrimitiveC {
|
||||||
public:
|
public:
|
||||||
Custom() : PrimitiveC(kNameCustom) {}
|
Custom() : PrimitiveC(kNameCustom) {}
|
||||||
~Custom() = default;
|
~Custom() override = default;
|
||||||
MS_DECLARE_PARENT(Custom, PrimitiveC);
|
MS_DECLARE_PARENT(Custom, PrimitiveC);
|
||||||
void Init(const std::vector<int64_t> &custom);
|
void Init(const std::string &type, const std::map<std::string, std::vector<uint8_t>> &attrs);
|
||||||
void set_custom(const std::vector<int64_t> &custom);
|
void set_type(const std::string &type);
|
||||||
std::vector<int64_t> get_custom() const;
|
std::string get_type() const;
|
||||||
|
void set_attr(const std::map<std::string, std::vector<uint8_t>> &attrs);
|
||||||
|
std::map<std::string, std::vector<uint8_t>> get_attr() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
using PrimCustomPtr = std::shared_ptr<Custom>;
|
using PrimCustomPtr = std::shared_ptr<Custom>;
|
||||||
|
|
|
@ -31,6 +31,7 @@ constexpr auto kActivation = "activation";
|
||||||
constexpr auto kActivationType = "activation_type";
|
constexpr auto kActivationType = "activation_type";
|
||||||
constexpr auto kAddress = "address";
|
constexpr auto kAddress = "address";
|
||||||
constexpr auto kAlignCorners = "align_corners";
|
constexpr auto kAlignCorners = "align_corners";
|
||||||
|
constexpr auto kAttr = "attr";
|
||||||
constexpr auto kAspectRatios = "aspect_ratios";
|
constexpr auto kAspectRatios = "aspect_ratios";
|
||||||
constexpr auto kAxes = "axes";
|
constexpr auto kAxes = "axes";
|
||||||
constexpr auto kAxis = "axis";
|
constexpr auto kAxis = "axis";
|
||||||
|
|
|
@ -208,6 +208,7 @@ union PrimitiveType {
|
||||||
Splice,
|
Splice,
|
||||||
LogSoftmax,
|
LogSoftmax,
|
||||||
Call,
|
Call,
|
||||||
|
Custom,
|
||||||
}
|
}
|
||||||
|
|
||||||
table Abs {
|
table Abs {
|
||||||
|
@ -1103,3 +1104,8 @@ table LogSoftmax {
|
||||||
|
|
||||||
table Call {
|
table Call {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
table Custom {
|
||||||
|
type: string;
|
||||||
|
attr: [Attribute];
|
||||||
|
}
|
||||||
|
|
|
@ -142,3 +142,8 @@ table Vec {
|
||||||
table Vec2D {
|
table Vec2D {
|
||||||
data: [Vec];
|
data: [Vec];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
table Attribute {
|
||||||
|
name: string;
|
||||||
|
data: [ubyte];
|
||||||
|
}
|
||||||
|
|
|
@ -207,6 +207,7 @@ OP_TYPE(ResizeGrad)
|
||||||
OP_TYPE(Splice)
|
OP_TYPE(Splice)
|
||||||
OP_TYPE(LogSoftmax)
|
OP_TYPE(LogSoftmax)
|
||||||
OP_TYPE(Call)
|
OP_TYPE(Call)
|
||||||
|
OP_TYPE(Custom)
|
||||||
OP_TYPE_DEF_END(PrimitiveType)
|
OP_TYPE_DEF_END(PrimitiveType)
|
||||||
|
|
||||||
OP_SCHEMA_DEF(Abs)
|
OP_SCHEMA_DEF(Abs)
|
||||||
|
@ -1102,3 +1103,8 @@ OP_SCHEMA_DEF_END(LogSoftmax)
|
||||||
|
|
||||||
OP_SCHEMA_DEF(Call)
|
OP_SCHEMA_DEF(Call)
|
||||||
OP_SCHEMA_DEF_END(Call)
|
OP_SCHEMA_DEF_END(Call)
|
||||||
|
|
||||||
|
OP_SCHEMA_DEF_ONLY(Custom)
|
||||||
|
OP_ATTR_ONLY(type, string)
|
||||||
|
OP_ATTR_ONLY(attr, [Attribute])
|
||||||
|
OP_SCHEMA_DEF_ONLY_END(Custom)
|
||||||
|
|
|
@ -975,6 +975,39 @@ RegistryMSOps g_erfPrimitiveCreatorRegistry("Erf", ErfPrimitiveCreator);
|
||||||
RegistryMSOps g_SplicePrimitiveCreatorRegistry("Splice", SplicePrimitiveCreator);
|
RegistryMSOps g_SplicePrimitiveCreatorRegistry("Splice", SplicePrimitiveCreator);
|
||||||
RegistryMSOps g_LogSoftmaxPrimitiveCreatorRegistry("LogSoftmax", LogSoftmaxPrimitiveCreator);
|
RegistryMSOps g_LogSoftmaxPrimitiveCreatorRegistry("LogSoftmax", LogSoftmaxPrimitiveCreator);
|
||||||
RegistryMSOps g_CallPrimitiveCreatorRegistry("call", CallPrimitiveCreator);
|
RegistryMSOps g_CallPrimitiveCreatorRegistry("call", CallPrimitiveCreator);
|
||||||
|
|
||||||
|
schema::PrimitiveT *CustomPrimitiveCreator(const AnfNodePtr &node) {
|
||||||
|
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Custom>>(node);
|
||||||
|
auto *schema_op = new (std::nothrow) schema::CustomT();
|
||||||
|
if (schema_op == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (ms_primc->GetAttr("type") != nullptr) {
|
||||||
|
schema_op->type = ms_primc->get_type();
|
||||||
|
}
|
||||||
|
if (ms_primc->GetAttr("attr") != nullptr) {
|
||||||
|
auto attr_map = ms_primc->get_attr();
|
||||||
|
for (const auto &attr_item : attr_map) {
|
||||||
|
auto *attr = new (std::nothrow) schema::AttributeT();
|
||||||
|
if (attr == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
attr->name = attr_item.first;
|
||||||
|
attr->data = attr_item.second;
|
||||||
|
schema_op->attr.emplace_back(attr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto *prim = new (std::nothrow) schema::PrimitiveT();
|
||||||
|
if (prim == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
prim->value.value = schema_op;
|
||||||
|
prim->value.type = schema::PrimitiveType_Custom;
|
||||||
|
return prim;
|
||||||
|
}
|
||||||
|
|
||||||
|
RegistryMSOps g_CustomPrimitiveCreatorRegistry("Custom", CustomPrimitiveCreator);
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue