forked from mindspore-Ecosystem/mindspore
add UnPack method in ops & remove anf_importer populater
This commit is contained in:
parent
07a75658bf
commit
28e3508718
|
@ -6,7 +6,7 @@ if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_
|
|||
endif ()
|
||||
|
||||
set(MS_VERSION_MAJOY 0)
|
||||
set(MS_VERSION_MINOR 6)
|
||||
set(MS_VERSION_MINOR 7)
|
||||
set(MS_VERSION_REVISION 0)
|
||||
|
||||
set(DIR_PREFIX mindspore-lite)
|
||||
|
|
|
@ -5,13 +5,13 @@ BASE_PATH=$(cd "$(dirname $0)"; pwd)
|
|||
TOP_PATH="${BASE_PATH}/../../.."
|
||||
# build mindspore-lite arm64
|
||||
cd ${TOP_PATH}
|
||||
#bash build.sh -I arm64
|
||||
#COMPILE_RET=$?
|
||||
bash build.sh -I arm64
|
||||
COMPILE_RET=$?
|
||||
|
||||
#if [[ "${COMPILE_RET}" -ne 0 ]]; then
|
||||
# echo "---------------- mindspore lite: build failed ----------------"
|
||||
# exit
|
||||
#fi
|
||||
if [[ "${COMPILE_RET}" -ne 0 ]]; then
|
||||
echo "---------------- mindspore lite: build failed ----------------"
|
||||
exit
|
||||
fi
|
||||
|
||||
# copy arm64 so
|
||||
cd ${TOP_PATH}/output/
|
||||
|
|
|
@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.14)
|
|||
project (Lite-java)
|
||||
|
||||
set(MS_VERSION_MAJOY 0)
|
||||
set(MS_VERSION_MINOR 6)
|
||||
set(MS_VERSION_MINOR 7)
|
||||
set(MS_VERSION_REVISION 0)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DMS_VERSION_MAJOY=${MS_VERSION_MAJOY} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMS_VERSION_MAJOY=${MS_VERSION_MAJOY} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
|
||||
|
|
|
@ -32,6 +32,7 @@ std::vector<size_t> GetGraphInputNodes(const schema::MetaGraph *meta_graph) {
|
|||
for (size_t j = 0; j < meta_graph->nodes()->size(); j++) {
|
||||
auto *cNode = meta_graph->nodes()->GetAs<schema::CNode>(j);
|
||||
MS_ASSERT(nullptr != cNode);
|
||||
MS_ASSERT(nullptr != cNode->inputIndex());
|
||||
for (size_t k = 0; k < cNode->inputIndex()->size(); k++) {
|
||||
if (cNode->inputIndex()->GetAs<uint32_t>(k) == input_index) {
|
||||
if (!IsContain<size_t>(ret, j)) {
|
||||
|
|
|
@ -53,6 +53,7 @@ int Executor::Run(std::vector<tensor::Tensor *> &in_tensors, std::vector<tensor:
|
|||
MS_LOG(ERROR) << "run kernel failed, name: " << kernel->name();
|
||||
return ret;
|
||||
}
|
||||
|
||||
if (after != nullptr) {
|
||||
if (!after(PackToMSTensors(kernel->in_tensors()), PackToMSTensors(kernel->out_tensors()),
|
||||
{kernel->name(), kernel->type_str()})) {
|
||||
|
|
|
@ -188,186 +188,6 @@ void ModelImpl::FreeMetaGraph() {
|
|||
|
||||
const schema::MetaGraph *ModelImpl::meta_graph() const { return this->meta_graph_; }
|
||||
|
||||
PrimitiveC *ModelImpl::CopyPrimitive(const schema::Primitive *src_prim) {
|
||||
MS_EXCEPTION_IF_NULL(src_prim);
|
||||
auto op_type = src_prim->value_type();
|
||||
switch (op_type) {
|
||||
case schema::PrimitiveType_SoftMax:
|
||||
return new SoftMax(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Activation:
|
||||
return new Activation(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Conv2D:
|
||||
return new Conv2D(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_DeConv2D:
|
||||
return new DeConv2D(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Reduce:
|
||||
return new Reduce(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Pooling:
|
||||
return new Pooling(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_DepthwiseConv2D:
|
||||
return new DepthwiseConv2D(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_FusedBatchNorm:
|
||||
return new FusedBatchNorm(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_BatchNorm:
|
||||
return new BatchNorm(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_FullConnection:
|
||||
return new FullConnection(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Power:
|
||||
return new Power(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Range:
|
||||
return new Range(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Mul:
|
||||
return new Mul(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Add:
|
||||
return new Add(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Sub:
|
||||
return new Sub(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Div:
|
||||
return new Div(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_BiasAdd:
|
||||
return new BiasAdd(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_ExpandDims:
|
||||
return new ExpandDims(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_ArgMax:
|
||||
return new ArgMax(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_ArgMin:
|
||||
return new ArgMin(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Cast:
|
||||
return new Cast(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Reshape:
|
||||
return new Reshape(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Scale:
|
||||
return new Scale(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Eltwise:
|
||||
return new Eltwise(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Concat:
|
||||
return new Concat(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Fill:
|
||||
return new Fill(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Transpose:
|
||||
return new Transpose(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Slice:
|
||||
return new SliceOp(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Squeeze:
|
||||
return new Squeeze(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Nchw2Nhwc:
|
||||
return new Nchw2Nhwc(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Nhwc2Nchw:
|
||||
return new Nhwc2Nchw(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Flatten:
|
||||
return new Flatten(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Mean:
|
||||
return new Mean(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Stack:
|
||||
return new Stack(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Crop:
|
||||
return new Crop(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_SquaredDifference:
|
||||
return new SquaredDifference(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_AddN:
|
||||
return new AddN(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Abs:
|
||||
return new Abs(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Sin:
|
||||
return new Sin(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Cos:
|
||||
return new Cos(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Log:
|
||||
return new Log(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Sqrt:
|
||||
return new Sqrt(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Rsqrt:
|
||||
return new Rsqrt(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Square:
|
||||
return new Square(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Exp:
|
||||
return new Exp(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Gather:
|
||||
return new Gather(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_GatherNd:
|
||||
return new GatherNd(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_LocalResponseNormalization:
|
||||
return new LocalResponseNormalization(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Maximum:
|
||||
return new Maximum(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Minimum:
|
||||
return new Minimum(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Pad:
|
||||
return new Pad(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_StridedSlice:
|
||||
return new StridedSlice(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Prelu:
|
||||
return new Prelu(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_CaffePReLU:
|
||||
return new CaffePReLU(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Round:
|
||||
return new Round(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Reverse:
|
||||
return new Reverse(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_ReverseSequence:
|
||||
return new ReverseSequence(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_LogicalAnd:
|
||||
return new LogicalAnd(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_LogicalOr:
|
||||
return new LogicalOr(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_LogicalNot:
|
||||
return new LogicalNot(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_FloorDiv:
|
||||
return new FloorDiv(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_FloorMod:
|
||||
return new FloorMod(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Equal:
|
||||
return new Equal(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_NotEqual:
|
||||
return new NotEqual(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Less:
|
||||
return new Less(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_LessEqual:
|
||||
return new LessEqual(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Greater:
|
||||
return new Greater(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_GreaterEqual:
|
||||
return new GreaterEqual(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Floor:
|
||||
return new Floor(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Ceil:
|
||||
return new Ceil(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Split:
|
||||
return new Split(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_OneHot:
|
||||
return new OneHot(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_SpaceToDepth:
|
||||
return new SpaceToDepth(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Tile:
|
||||
return new Tile(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Resize:
|
||||
return new Resize(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Unstack:
|
||||
return new Unstack(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Unique:
|
||||
return new Unique(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_TopK:
|
||||
return new TopK(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_MatMul:
|
||||
return new MatMul(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_QuantDTypeCast:
|
||||
return new QuantDTypeCast(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_EmbeddingLookup:
|
||||
return new EmbeddingLookup(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Elu:
|
||||
return new Elu(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_DeDepthwiseConv2D:
|
||||
return new DeDepthwiseConv2D(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Shape:
|
||||
return new Shape(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Unsqueeze:
|
||||
return new Unsqueeze(const_cast<schema::Primitive *>(src_prim));
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int ModelImpl::BuildOps() {
|
||||
if (this->meta_graph_ == nullptr) {
|
||||
MS_LOG(ERROR) << "mete_graph is nullptr";
|
||||
|
@ -379,7 +199,7 @@ int ModelImpl::BuildOps() {
|
|||
auto name = cNode->name()->str();
|
||||
auto srcPrim = cNode->primitive();
|
||||
|
||||
this->ops_[name] = CopyPrimitive(srcPrim);
|
||||
this->ops_[name] = PrimitiveC::UnPackFromSchemaPrimitive(const_cast<schema::Primitive *>(srcPrim));
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -33,9 +33,11 @@ namespace lite {
|
|||
class Abs : public ArithmeticSelf {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Abs() = default;
|
||||
explicit Abs(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Abs(schema::Primitive *primitive) : ArithmeticSelf(primitive) {}
|
||||
#endif
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "src/ops/activation.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -25,6 +26,21 @@ float Activation::GetAlpha() const { return this->primitive_->value.AsActivation
|
|||
void Activation::SetType(int type) { this->primitive_->value.AsActivation()->type = (schema::ActivationType)type; }
|
||||
void Activation::SetAlpha(float alpha) { this->primitive_->value.AsActivation()->alpha = alpha; }
|
||||
|
||||
int Activation::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
this->primitive_ = new (schema::PrimitiveT);
|
||||
auto attr = std::make_unique<schema::ActivationT>();
|
||||
if (prim.name() == "ReLU") {
|
||||
attr->type = schema::ActivationType_RELU;
|
||||
} else if (prim.name() == "Sigmoid") {
|
||||
attr->type = schema::ActivationType_SIGMOID;
|
||||
} else if (prim.name() == "ReLU6") {
|
||||
attr->type = schema::ActivationType_RELU6;
|
||||
}
|
||||
this->primitive_->value.type = schema::PrimitiveType_Activation;
|
||||
this->primitive_->value.value = attr.release();
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
#else
|
||||
|
||||
int Activation::GetType() const { return this->primitive_->value_as_Activation()->type(); }
|
||||
|
|
|
@ -27,9 +27,12 @@ namespace lite {
|
|||
class Activation : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Activation() = default;
|
||||
explicit Activation(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
#else
|
||||
explicit Activation(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
int GetType() const;
|
||||
float GetAlpha() const;
|
||||
void SetType(int type);
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class ActivationGrad : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
ActivationGrad() = default;
|
||||
explicit ActivationGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit ActivationGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int GetType() const;
|
||||
void SetType(int type);
|
||||
};
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "src/ops/add.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -25,6 +26,29 @@ void Add::SetActivationType(int activation_type) {
|
|||
this->primitive_->value.AsAdd()->activationType = (schema::ActivationType)activation_type;
|
||||
}
|
||||
|
||||
int Add::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
if (this->primitive_ == nullptr) {
|
||||
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
||||
if (this->primitive_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
this->primitive_->value.type = schema::PrimitiveType_Add;
|
||||
}
|
||||
if (this->primitive_->value.type != schema::PrimitiveType_Add) {
|
||||
MS_LOG(ERROR) << "Primitive type should be add";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
this->primitive_->value.value = new (std::nothrow) schema::AddT();
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT value failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
int Add::GetActivationType() const { return this->primitive_->value_as_Add()->activationType(); }
|
||||
|
|
|
@ -33,10 +33,12 @@ namespace lite {
|
|||
class Add : public Arithmetic {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Add() = default;
|
||||
explicit Add(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
#else
|
||||
explicit Add(schema::Primitive *primitive) : Arithmetic(primitive) {}
|
||||
|
||||
#endif
|
||||
int GetActivationType() const;
|
||||
void SetActivationType(int activation_type);
|
||||
};
|
||||
|
|
|
@ -28,9 +28,11 @@ namespace lite {
|
|||
class AddN : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
AddN() = default;
|
||||
explicit AddN(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit AddN(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetN() const;
|
||||
void SetN(int n);
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class ArgMax : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
ArgMax() = default;
|
||||
explicit ArgMax(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit ArgMax(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetAxis() const;
|
||||
bool GetOutMaxValue() const;
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class ArgMin : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
ArgMin() = default;
|
||||
explicit ArgMin(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit ArgMin(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetAxis() const;
|
||||
bool GetOutMaxValue() const;
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class Arithmetic : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Arithmetic() = default;
|
||||
explicit Arithmetic(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Arithmetic(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
bool Broadcasting() { return this->broadcasting_; }
|
||||
int NDims() { return this->ndim_; }
|
||||
|
|
|
@ -25,10 +25,11 @@ namespace lite {
|
|||
class ArithmeticSelf : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
ArithmeticSelf() = default;
|
||||
explicit ArithmeticSelf(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit ArithmeticSelf(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
};
|
||||
} // namespace lite
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "src/ops/batch_norm.h"
|
||||
|
||||
#include <memory>
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -23,6 +23,15 @@ float BatchNorm::GetEpsilon() const { return this->primitive_->value.AsBatchNorm
|
|||
|
||||
void BatchNorm::SetEpsilon(float epsilon) { this->primitive_->value.AsBatchNorm()->epsilon = epsilon; }
|
||||
|
||||
int BatchNorm::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
this->primitive_ = new (schema::PrimitiveT);
|
||||
auto attr = std::make_unique<schema::FusedBatchNormT>();
|
||||
attr->epsilon = GetValue<float>(prim.GetAttr("epsilon"));
|
||||
this->primitive_->value.type = schema::PrimitiveType_FusedBatchNorm;
|
||||
this->primitive_->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
float BatchNorm::GetEpsilon() const { return this->primitive_->value_as_BatchNorm()->epsilon(); }
|
||||
|
|
|
@ -28,10 +28,12 @@ namespace lite {
|
|||
class BatchNorm : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
BatchNorm() = default;
|
||||
explicit BatchNorm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
#else
|
||||
explicit BatchNorm(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
float GetEpsilon() const;
|
||||
void SetEpsilon(float epsilon);
|
||||
};
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class BatchToSpace : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
BatchToSpace() = default;
|
||||
explicit BatchToSpace(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit BatchToSpace(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
std::vector<int> GetBlockShape() const;
|
||||
std::vector<int> GetCrops() const;
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "src/ops/bias_add.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -23,6 +24,16 @@ std::vector<int> BiasAdd::GetAxis() const { return this->primitive_->value.AsBia
|
|||
|
||||
void BiasAdd::SetAxis(const std::vector<int> &axis) { this->primitive_->value.AsBiasAdd()->axis = axis; }
|
||||
|
||||
int BiasAdd::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
this->primitive_ = new (schema::PrimitiveT);
|
||||
auto attr = std::make_unique<schema::BiasAddT>();
|
||||
attr->axis = {0};
|
||||
this->primitive_->value.type = schema::PrimitiveType_BiasAdd;
|
||||
this->primitive_->value.value = attr.release();
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
std::vector<int> BiasAdd::GetAxis() const {
|
||||
|
|
|
@ -28,10 +28,12 @@ namespace lite {
|
|||
class BiasAdd : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
BiasAdd() = default;
|
||||
explicit BiasAdd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
#else
|
||||
explicit BiasAdd(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
std::vector<int> GetAxis() const;
|
||||
void SetAxis(const std::vector<int> &axis);
|
||||
};
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class BiasGrad : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
BiasGrad() = default;
|
||||
explicit BiasGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit BiasGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
std::vector<int> GetAxis() const;
|
||||
void SetAxis(const std::vector<int> &axis);
|
||||
};
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class BNGradInput : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
BNGradInput() = default;
|
||||
explicit BNGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit BNGradInput(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
float GetEps() const;
|
||||
int GetChannels() const;
|
||||
void SetEps(float eps);
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class BroadcastTo : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
BroadcastTo() = default;
|
||||
explicit BroadcastTo(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit BroadcastTo(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
std::vector<int> GetDstShape() const;
|
||||
void SetDstShape(const std::vector<int> &dst_shape);
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class CaffePReLU : public Activation {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
CaffePReLU() = default;
|
||||
explicit CaffePReLU(schema::PrimitiveT *primitive) : Activation(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit CaffePReLU(schema::Primitive *primitive) : Activation(primitive) {}
|
||||
|
||||
#endif
|
||||
bool GetChannelShared() const;
|
||||
void SetChannelShared(bool channel_shared);
|
||||
};
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class Cast : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Cast() = default;
|
||||
explicit Cast(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Cast(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetSrcT() const;
|
||||
int GetDstT() const;
|
||||
|
|
|
@ -28,9 +28,11 @@ namespace lite {
|
|||
class Ceil : public ArithmeticSelf {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Ceil() = default;
|
||||
explicit Ceil(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Ceil(schema::Primitive *primitive) : ArithmeticSelf(primitive) {}
|
||||
#endif
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class Clip : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Clip() = default;
|
||||
explicit Clip(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Clip(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
float GetMax() const;
|
||||
float GetMin() const;
|
||||
void SetMax(float max);
|
||||
|
|
|
@ -15,9 +15,11 @@
|
|||
*/
|
||||
|
||||
#include "src/ops/concat.h"
|
||||
#include <memory>
|
||||
#include "include/errorcode.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "src/ir/tensor.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -27,6 +29,16 @@ int Concat::GetN() const { return this->primitive_->value.AsConcat()->n; }
|
|||
void Concat::SetAxis(int axis) { this->primitive_->value.AsConcat()->axis = axis; }
|
||||
void Concat::SetN(int n) { this->primitive_->value.AsConcat()->n = n; }
|
||||
|
||||
int Concat::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
this->primitive_ = new (schema::PrimitiveT);
|
||||
auto attr = std::make_unique<schema::ConcatT>();
|
||||
auto prim_axis = GetValue<int>(prim.GetAttr("axis"));
|
||||
attr->axis = prim_axis;
|
||||
this->primitive_->value.type = schema::PrimitiveType_Concat;
|
||||
this->primitive_->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
int Concat::GetAxis() const { return this->primitive_->value_as_Concat()->axis(); }
|
||||
|
|
|
@ -28,10 +28,12 @@ namespace lite {
|
|||
class Concat : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Concat() = default;
|
||||
explicit Concat(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
#else
|
||||
explicit Concat(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetAxis() const;
|
||||
int GetN() const;
|
||||
|
|
|
@ -28,9 +28,11 @@ namespace lite {
|
|||
class ConstantOfShape : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
ConstantOfShape() = default;
|
||||
explicit ConstantOfShape(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit ConstantOfShape(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
float GetValue() const;
|
||||
void SetValue(float value);
|
||||
|
|
|
@ -15,9 +15,14 @@
|
|||
*/
|
||||
|
||||
#include "src/ops/conv2d.h"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "include/errorcode.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "src/ir/tensor.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -63,6 +68,265 @@ void Conv2D::SetHasBias(bool has_bias) { this->primitive_->value.AsConv2D()->has
|
|||
void Conv2D::SetActivationType(int activation_type) {
|
||||
this->primitive_->value.AsConv2D()->activationType = (schema::ActivationType)activation_type;
|
||||
}
|
||||
template <typename T>
|
||||
void ConvertConvWeight(const ParameterPtr ¶m_node) {
|
||||
MS_ASSERT(param_node != nullptr);
|
||||
auto param = param_node->default_param();
|
||||
auto weight = std::dynamic_pointer_cast<ParamValueLite>(param);
|
||||
MS_ASSERT(weight != nullptr);
|
||||
|
||||
std::unique_ptr<T> buf(new (std::nothrow) T[weight->tensor_shape_size()]);
|
||||
if (buf == nullptr) {
|
||||
MS_LOG(ERROR) << "new buf failed";
|
||||
return;
|
||||
}
|
||||
|
||||
size_t filter_k = weight->tensor_shape()[0];
|
||||
size_t filter_c = weight->tensor_shape()[1];
|
||||
size_t filter_h = weight->tensor_shape()[2];
|
||||
size_t filter_w = weight->tensor_shape()[3];
|
||||
T *p1Buff = nullptr;
|
||||
T *p2Buff = nullptr;
|
||||
for (size_t k = 0; k < filter_k; ++k) {
|
||||
for (size_t c = 0; c < filter_c; ++c) {
|
||||
for (size_t h = 0; h < filter_h; ++h) {
|
||||
for (size_t w = 0; w < filter_w; ++w) {
|
||||
p1Buff = reinterpret_cast<float *>(weight->tensor_addr()) +
|
||||
((k * filter_c * filter_h * filter_w) + (c * filter_h * filter_w) + (h * filter_w) + (w));
|
||||
p2Buff =
|
||||
buf.get() + ((c * filter_k * filter_h * filter_w) + (k * filter_h * filter_w) + (h * filter_w) + (w));
|
||||
*p2Buff = *p1Buff;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto ret = ::memcpy_s(weight->tensor_addr(), weight->tensor_shape_size() * sizeof(T), buf.get(),
|
||||
weight->tensor_shape_size() * sizeof(T));
|
||||
if (ret != EOK) {
|
||||
MS_LOG(ERROR) << "memcpy_s failed: " << ret;
|
||||
return;
|
||||
}
|
||||
|
||||
auto abstract_base = param_node->abstract();
|
||||
MS_ASSERT(abstract_base != nullptr);
|
||||
if (utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
|
||||
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
|
||||
utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape()[0] = filter_c;
|
||||
utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape()[1] = filter_k;
|
||||
utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape()[2] = filter_h;
|
||||
utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape()[3] = filter_w;
|
||||
}
|
||||
return;
|
||||
}
|
||||
void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group,
|
||||
const std::vector<AnfNodePtr> &inputs) {
|
||||
auto attr = std::make_unique<schema::DepthwiseConv2DT>();
|
||||
auto format = GetValue<std::string>(prim.GetAttr("data_format"));
|
||||
if (format == "NCHW") {
|
||||
attr->format = schema::Format_NCHW;
|
||||
} else if (format == "NHWC") {
|
||||
attr->format = schema::Format_NHWC;
|
||||
} else {
|
||||
attr->format = schema::Format_NUM_OF_FORMAT;
|
||||
}
|
||||
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
|
||||
attr->padUp = pad_list[0];
|
||||
attr->padDown = pad_list[1];
|
||||
attr->padLeft = pad_list[2];
|
||||
attr->padRight = pad_list[3];
|
||||
|
||||
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
|
||||
attr->dilateH = dilation[0];
|
||||
attr->dilateW = dilation[1];
|
||||
|
||||
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
|
||||
attr->kernelH = kernel_size[0];
|
||||
attr->kernelW = kernel_size[1];
|
||||
|
||||
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
|
||||
attr->strideH = stride[2];
|
||||
attr->strideW = stride[3];
|
||||
|
||||
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
|
||||
if (pad_mode == "valid") {
|
||||
attr->padMode = schema::PadMode_VALID;
|
||||
} else if (pad_mode == "same") {
|
||||
attr->padMode = schema::PadMode_SAME;
|
||||
} else {
|
||||
attr->padMode = schema::PadMode_NOTSET;
|
||||
}
|
||||
|
||||
int channel_mutiplier = 1;
|
||||
if (prim.GetAttr("channel_mutiplier") != nullptr) {
|
||||
channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier"));
|
||||
}
|
||||
attr->channelMultiplier = channel_mutiplier;
|
||||
|
||||
MS_ASSERT(inputs.size() == kAnfPopulaterTwo);
|
||||
auto input_node = inputs[kAnfPopulaterOne];
|
||||
MS_ASSERT(input_node != nullptr);
|
||||
if (input_node->isa<Parameter>()) {
|
||||
auto param_node = input_node->cast<ParameterPtr>();
|
||||
ConvertConvWeight<float>(param_node);
|
||||
}
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
|
||||
primitive->value.value = attr.release();
|
||||
}
|
||||
|
||||
void Conv2D::PopulaterConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group) {
|
||||
auto attr = std::make_unique<schema::Conv2DT>();
|
||||
attr->group = group;
|
||||
auto format = GetValue<std::string>(prim.GetAttr("data_format"));
|
||||
if (format == "NCHW") {
|
||||
attr->format = schema::Format_NCHW;
|
||||
} else if (format == "NHWC") {
|
||||
attr->format = schema::Format_NHWC;
|
||||
} else {
|
||||
attr->format = schema::Format_NUM_OF_FORMAT;
|
||||
}
|
||||
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
|
||||
attr->padUp = pad_list[0];
|
||||
attr->padDown = pad_list[1];
|
||||
attr->padLeft = pad_list[2];
|
||||
attr->padRight = pad_list[3];
|
||||
|
||||
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
|
||||
attr->dilateH = dilation[0];
|
||||
attr->dilateW = dilation[1];
|
||||
|
||||
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
|
||||
attr->kernelH = kernel_size[0];
|
||||
attr->kernelW = kernel_size[1];
|
||||
|
||||
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
|
||||
attr->strideH = stride[2];
|
||||
attr->strideW = stride[3];
|
||||
|
||||
attr->channelOut = GetValue<int>(prim.GetAttr("out_channel"));
|
||||
|
||||
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
|
||||
if (pad_mode == "valid") {
|
||||
attr->padMode = schema::PadMode_VALID;
|
||||
} else if (pad_mode == "same") {
|
||||
attr->padMode = schema::PadMode_SAME;
|
||||
} else {
|
||||
attr->padMode = schema::PadMode_NOTSET;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Conv2D;
|
||||
primitive->value.value = attr.release();
|
||||
}
|
||||
|
||||
void Conv2D::CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax) {
|
||||
constexpr float qmin = 0;
|
||||
constexpr float qmax = 255;
|
||||
*mMin = static_cast<float>((qmin - mean) / stdDev);
|
||||
*mMax = static_cast<float>((qmax - mean) / stdDev);
|
||||
}
|
||||
|
||||
void Conv2D::PopulaterQuantParam(const Primitive &prim,
|
||||
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
|
||||
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam) {
|
||||
auto narrow_range = prim.GetAttr("narrow_range");
|
||||
bool narrowRangeQuantParam = GetValue<bool>(narrow_range);
|
||||
auto num_bits = prim.GetAttr("num_bits");
|
||||
int32_t numbitsRangeQuantParam = GetValue<int32_t>(num_bits);
|
||||
|
||||
std::vector<schema::QuantParamT> quants;
|
||||
schema::QuantParamT quantParam;
|
||||
auto mean = prim.GetAttr("mean");
|
||||
auto std_dev = prim.GetAttr("std_dev");
|
||||
if (mean != nullptr && std_dev != nullptr) {
|
||||
auto meanQuantOaram = GetValue<double>(mean);
|
||||
double stddevQuantOaram = GetValue<double>(std_dev);
|
||||
float mMin = 0.0;
|
||||
float mMax = 0.0;
|
||||
CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax);
|
||||
quantParam.min = mMin;
|
||||
quantParam.max = mMax;
|
||||
} else {
|
||||
auto inputMin = prim.GetAttr("input_minq");
|
||||
auto inputMax = prim.GetAttr("input_maxq");
|
||||
auto inputMinPtr = inputMin->cast<lite::tensor::TensorPtr>();
|
||||
auto inputMaxPtr = inputMax->cast<lite::tensor::TensorPtr>();
|
||||
float *minBuf = static_cast<float *>(inputMinPtr->Data());
|
||||
float *maxBuf = static_cast<float *>(inputMaxPtr->Data());
|
||||
quantParam.min = *minBuf;
|
||||
quantParam.max = *maxBuf;
|
||||
}
|
||||
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
|
||||
numbitsRangeQuantParam);
|
||||
quants.emplace_back(quantParam);
|
||||
vecInputQuantParam->emplace_back(quants);
|
||||
|
||||
quants.clear();
|
||||
int biasQuantSize = 0;
|
||||
auto filterMin = prim.GetAttr("filter_minq");
|
||||
auto filterMax = prim.GetAttr("filter_maxq");
|
||||
if (filterMin != nullptr && filterMax != nullptr) {
|
||||
auto filterMinPtr = filterMin->cast<lite::tensor::TensorPtr>();
|
||||
auto filterMaxPtr = filterMax->cast<lite::tensor::TensorPtr>();
|
||||
float *minBuf = static_cast<float *>(filterMinPtr->Data());
|
||||
float *maxBuf = static_cast<float *>(filterMaxPtr->Data());
|
||||
biasQuantSize = filterMinPtr->DataSize();
|
||||
for (int i = 0; i < biasQuantSize; ++i) {
|
||||
quantParam.min = *(minBuf++);
|
||||
quantParam.max = *(maxBuf++);
|
||||
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
|
||||
numbitsRangeQuantParam);
|
||||
quants.emplace_back(quantParam);
|
||||
}
|
||||
vecInputQuantParam->emplace_back(quants);
|
||||
}
|
||||
|
||||
quants.clear();
|
||||
for (int i = 0; i < biasQuantSize; ++i) {
|
||||
quantParam.min = 0.0;
|
||||
quantParam.max = 0.0;
|
||||
quantParam.zeroPoint = 0;
|
||||
|
||||
quantParam.scale = vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(i).scale;
|
||||
quants.emplace_back(quantParam);
|
||||
}
|
||||
vecInputQuantParam->emplace_back(quants);
|
||||
|
||||
quants.clear();
|
||||
auto outputMin = prim.GetAttr("output_minq");
|
||||
auto outputMax = prim.GetAttr("output_maxq");
|
||||
if (outputMin != nullptr && outputMax != nullptr) {
|
||||
auto outputMinPtr = outputMin->cast<lite::tensor::TensorPtr>();
|
||||
auto outputMaxPtr = outputMax->cast<lite::tensor::TensorPtr>();
|
||||
float *minBuf = static_cast<float *>(outputMinPtr->Data());
|
||||
float *maxBuf = static_cast<float *>(outputMaxPtr->Data());
|
||||
quantParam.min = *minBuf;
|
||||
quantParam.max = *maxBuf;
|
||||
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
|
||||
numbitsRangeQuantParam);
|
||||
quants.emplace_back(quantParam);
|
||||
vecOutputQuantParam->emplace_back(quants);
|
||||
}
|
||||
}
|
||||
|
||||
int Conv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
this->primitive_ = new (schema::PrimitiveT);
|
||||
|
||||
int group = GetValue<int>(prim.GetAttr("group"));
|
||||
if (group > 1) {
|
||||
PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs);
|
||||
} else {
|
||||
PopulaterConv2DSingleGroup(prim, this->primitive_, group);
|
||||
}
|
||||
|
||||
if (GetQuantType() == schema::QuantType_AwareTraining) {
|
||||
std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
|
||||
std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
|
||||
PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam);
|
||||
SetInputQuantParam(vecInputQuantParam);
|
||||
SetOutputQuantParam(vecOutputQuantParam);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
|
|
|
@ -20,17 +20,35 @@
|
|||
#include <vector>
|
||||
#include <set>
|
||||
#include <cmath>
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include <memory>
|
||||
#include "src/ops/primitive_c.h"
|
||||
#include "ir/dtype/type_id.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class Conv2D : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
||||
public:
|
||||
Conv2D() = default;
|
||||
explicit Conv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
|
||||
private:
|
||||
void PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group,
|
||||
const std::vector<AnfNodePtr> &inputs);
|
||||
void PopulaterConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group);
|
||||
void PopulaterQuantParam(const Primitive &prim, std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
|
||||
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam);
|
||||
void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax);
|
||||
#else
|
||||
|
||||
public:
|
||||
explicit Conv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
|
||||
public:
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int PadUp() const;
|
||||
int PadDown() const;
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class Conv2DGradFilter : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Conv2DGradFilter() = default;
|
||||
explicit Conv2DGradFilter(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Conv2DGradFilter(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int GetFormat() const;
|
||||
int GetGroup() const;
|
||||
int GetChannelIn() const;
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class Conv2DGradInput : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Conv2DGradInput() = default;
|
||||
explicit Conv2DGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Conv2DGradInput(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int GetFormat() const;
|
||||
int GetGroup() const;
|
||||
int GetChannelIn() const;
|
||||
|
|
|
@ -28,9 +28,11 @@ namespace lite {
|
|||
class Cos : public ArithmeticSelf {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Cos() = default;
|
||||
explicit Cos(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Cos(schema::Primitive *primitive) : ArithmeticSelf(primitive) {}
|
||||
#endif
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class Crop : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Crop() = default;
|
||||
explicit Crop(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Crop(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
long GetAxis() const;
|
||||
std::vector<long> GetOffsets() const;
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class DeConv2D : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
DeConv2D() = default;
|
||||
explicit DeConv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit DeConv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetFormat() const;
|
||||
int GetGroup() const;
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class DeDepthwiseConv2D : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
DeDepthwiseConv2D() = default;
|
||||
explicit DeDepthwiseConv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit DeDepthwiseConv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetFormat() const;
|
||||
int GetChannelIn() const;
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class DepthToSpace : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
DepthToSpace() = default;
|
||||
explicit DepthToSpace(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit DepthToSpace(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetBlockSize() const;
|
||||
int GetFormat() const;
|
||||
|
|
|
@ -15,7 +15,11 @@
|
|||
*/
|
||||
|
||||
#include "src/ops/depthwise_conv2d.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#endif
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -65,6 +69,168 @@ void DepthwiseConv2D::SetActivationType(int activation_type) {
|
|||
this->primitive_->value.AsDepthwiseConv2D()->activationType = (schema::ActivationType)activation_type;
|
||||
}
|
||||
|
||||
void DepthwiseConv2D::CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax) {
|
||||
constexpr float qmin = 0;
|
||||
constexpr float qmax = 255;
|
||||
*mMin = static_cast<float>((qmin - mean) / stdDev);
|
||||
*mMax = static_cast<float>((qmax - mean) / stdDev);
|
||||
}
|
||||
|
||||
void DepthwiseConv2D::PopulaterQuantParam(const Primitive &prim,
|
||||
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
|
||||
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam) {
|
||||
auto narrow_range = prim.GetAttr("narrow_range");
|
||||
bool narrowRangeQuantParam = GetValue<bool>(narrow_range);
|
||||
auto num_bits = prim.GetAttr("num_bits");
|
||||
int32_t numbitsRangeQuantParam = GetValue<int32_t>(num_bits);
|
||||
|
||||
std::vector<schema::QuantParamT> quants;
|
||||
schema::QuantParamT quantParam;
|
||||
auto mean = prim.GetAttr("mean");
|
||||
auto std_dev = prim.GetAttr("std_dev");
|
||||
if (mean != nullptr && std_dev != nullptr) {
|
||||
auto meanQuantOaram = GetValue<double>(mean);
|
||||
double stddevQuantOaram = GetValue<double>(std_dev);
|
||||
float mMin = 0.0;
|
||||
float mMax = 0.0;
|
||||
CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax);
|
||||
quantParam.min = mMin;
|
||||
quantParam.max = mMax;
|
||||
} else {
|
||||
auto inputMin = prim.GetAttr("input_minq");
|
||||
auto inputMax = prim.GetAttr("input_maxq");
|
||||
auto inputMinPtr = inputMin->cast<lite::tensor::TensorPtr>();
|
||||
auto inputMaxPtr = inputMax->cast<lite::tensor::TensorPtr>();
|
||||
float *minBuf = static_cast<float *>(inputMinPtr->Data());
|
||||
float *maxBuf = static_cast<float *>(inputMaxPtr->Data());
|
||||
quantParam.min = *minBuf;
|
||||
quantParam.max = *maxBuf;
|
||||
}
|
||||
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
|
||||
numbitsRangeQuantParam);
|
||||
quants.emplace_back(quantParam);
|
||||
vecInputQuantParam->emplace_back(quants);
|
||||
|
||||
quants.clear();
|
||||
int biasQuantSize = 0;
|
||||
auto filterMin = prim.GetAttr("filter_minq");
|
||||
auto filterMax = prim.GetAttr("filter_maxq");
|
||||
if (filterMin != nullptr && filterMax != nullptr) {
|
||||
auto filterMinPtr = filterMin->cast<lite::tensor::TensorPtr>();
|
||||
auto filterMaxPtr = filterMax->cast<lite::tensor::TensorPtr>();
|
||||
float *minBuf = static_cast<float *>(filterMinPtr->Data());
|
||||
float *maxBuf = static_cast<float *>(filterMaxPtr->Data());
|
||||
biasQuantSize = filterMinPtr->DataSize();
|
||||
for (int i = 0; i < biasQuantSize; ++i) {
|
||||
quantParam.min = *(minBuf++);
|
||||
quantParam.max = *(maxBuf++);
|
||||
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
|
||||
numbitsRangeQuantParam);
|
||||
quants.emplace_back(quantParam);
|
||||
}
|
||||
vecInputQuantParam->emplace_back(quants);
|
||||
}
|
||||
|
||||
quants.clear();
|
||||
for (int i = 0; i < biasQuantSize; ++i) {
|
||||
quantParam.min = 0.0;
|
||||
quantParam.max = 0.0;
|
||||
quantParam.zeroPoint = 0;
|
||||
|
||||
quantParam.scale = vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(i).scale;
|
||||
quants.emplace_back(quantParam);
|
||||
}
|
||||
vecInputQuantParam->emplace_back(quants);
|
||||
|
||||
quants.clear();
|
||||
auto outputMin = prim.GetAttr("output_minq");
|
||||
auto outputMax = prim.GetAttr("output_maxq");
|
||||
if (outputMin != nullptr && outputMax != nullptr) {
|
||||
auto outputMinPtr = outputMin->cast<lite::tensor::TensorPtr>();
|
||||
auto outputMaxPtr = outputMax->cast<lite::tensor::TensorPtr>();
|
||||
float *minBuf = static_cast<float *>(outputMinPtr->Data());
|
||||
float *maxBuf = static_cast<float *>(outputMaxPtr->Data());
|
||||
quantParam.min = *minBuf;
|
||||
quantParam.max = *maxBuf;
|
||||
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
|
||||
numbitsRangeQuantParam);
|
||||
quants.emplace_back(quantParam);
|
||||
vecOutputQuantParam->emplace_back(quants);
|
||||
}
|
||||
}
|
||||
|
||||
int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
this->primitive_ = new (schema::PrimitiveT);
|
||||
auto attr = std::make_unique<schema::DepthwiseConv2DT>();
|
||||
|
||||
auto format = GetValue<std::string>(prim.GetAttr("data_format"));
|
||||
if (format == "NCHW") {
|
||||
attr->format = schema::Format_NCHW;
|
||||
} else if (format == "NHWC") {
|
||||
attr->format = schema::Format_NHWC;
|
||||
} else {
|
||||
attr->format = schema::Format_NUM_OF_FORMAT;
|
||||
}
|
||||
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pads"));
|
||||
attr->padUp = pad_list[0];
|
||||
attr->padDown = pad_list[1];
|
||||
attr->padLeft = pad_list[2];
|
||||
attr->padRight = pad_list[3];
|
||||
|
||||
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
|
||||
attr->dilateH = dilation[0];
|
||||
attr->dilateW = dilation[1];
|
||||
|
||||
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
|
||||
attr->kernelH = kernel_size[0];
|
||||
attr->kernelW = kernel_size[1];
|
||||
|
||||
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
|
||||
attr->strideH = stride[2];
|
||||
attr->strideW = stride[3];
|
||||
|
||||
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
|
||||
if (pad_mode == "valid") {
|
||||
attr->padMode = schema::PadMode_VALID;
|
||||
} else if (pad_mode == "same") {
|
||||
attr->padMode = schema::PadMode_SAME;
|
||||
} else {
|
||||
attr->padMode = schema::PadMode_NOTSET;
|
||||
}
|
||||
|
||||
auto channel_multiplier = GetValue<int>(prim.GetAttr("channel_multiplier"));
|
||||
attr->channelMultiplier = channel_multiplier;
|
||||
|
||||
MS_ASSERT(inputs.size() == kAnfPopulaterTwo);
|
||||
auto inputNode = inputs[kAnfPopulaterOne];
|
||||
MS_ASSERT(inputNode != nullptr);
|
||||
if (inputNode->isa<Parameter>()) {
|
||||
auto paramNode = inputNode->cast<ParameterPtr>();
|
||||
auto abstractBase = paramNode->abstract();
|
||||
MS_ASSERT(abstractBase != nullptr);
|
||||
if (utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
|
||||
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
|
||||
MS_ASSERT(abstractTensor != nullptr);
|
||||
if (utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) {
|
||||
auto dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape();
|
||||
attr->channelIn = dims[kAnfPopulaterOne];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this->primitive_->value.type = schema::PrimitiveType_DepthwiseConv2D;
|
||||
this->primitive_->value.value = attr.release();
|
||||
|
||||
if (GetQuantType() == schema::QuantType_AwareTraining) {
|
||||
std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
|
||||
std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
|
||||
PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam);
|
||||
SetInputQuantParam(vecInputQuantParam);
|
||||
SetOutputQuantParam(vecOutputQuantParam);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
int DepthwiseConv2D::GetFormat() const { return this->primitive_->value_as_DepthwiseConv2D()->format(); }
|
||||
|
|
|
@ -26,12 +26,25 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
class DepthwiseConv2D : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
explicit DepthwiseConv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
explicit DepthwiseConv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
public:
|
||||
DepthwiseConv2D() = default;
|
||||
explicit DepthwiseConv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
|
||||
private:
|
||||
void PopulaterQuantParam(const Primitive &prim, std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
|
||||
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam);
|
||||
void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax);
|
||||
#else
|
||||
|
||||
public:
|
||||
explicit DepthwiseConv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
|
||||
public:
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetFormat() const;
|
||||
int GetChannelIn() const;
|
||||
|
|
|
@ -13,18 +13,20 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_ANF_MAKE_TUPLE_PARSER_H
|
||||
#define MINDSPORE_ANF_MAKE_TUPLE_PARSER_H
|
||||
#include "tools/anf_importer/anf_populater/anf_node_populater.h"
|
||||
#include "src/ops/dequant.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore::lite {
|
||||
class AnfMakeTuplePopulater : public AnfNodePopulater {
|
||||
public:
|
||||
AnfMakeTuplePopulater() = default;
|
||||
~AnfMakeTuplePopulater() override = default;
|
||||
int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr,
|
||||
const std::vector<AnfNodePtr> &inputs) override;
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
#endif // MINDSPORE_ANF_MAKE_TUPLE_PARSER_H
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int Dequant::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
this->primitive_ = new (schema::PrimitiveT);
|
||||
auto attr = std::make_unique<schema::OnnxInt8DequantizeT>();
|
||||
this->primitive_->value.type = schema::PrimitiveType_OnnxInt8Dequantize;
|
||||
this->primitive_->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_MINDSPORE_LITE_SRC_OPS_DEQUANT_H_
|
||||
#define LITE_MINDSPORE_LITE_SRC_OPS_DEQUANT_H_
|
||||
|
||||
#include <vector>
|
||||
#include "src/ops/primitive_c.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class Dequant : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Dequant() = default;
|
||||
explicit Dequant(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
#else
|
||||
explicit Dequant(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_SRC_OPS_DEQUANT_H_
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class DetectionPostProcess : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
DetectionPostProcess() = default;
|
||||
explicit DetectionPostProcess(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit DetectionPostProcess(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int GetFormat() const;
|
||||
int GetInputSize() const;
|
||||
float GetHScale() const;
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class Div : public Arithmetic {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Div() = default;
|
||||
explicit Div(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Div(schema::Primitive *primitive) : Arithmetic(primitive) {}
|
||||
|
||||
#endif
|
||||
int GetActivationType() const;
|
||||
void SetActivationType(int activation_type);
|
||||
};
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class Dropout : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Dropout() = default;
|
||||
explicit Dropout(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Dropout(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
float GetRatio() const;
|
||||
void SetRatio(float ratio);
|
||||
};
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class Eltwise : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Eltwise() = default;
|
||||
explicit Eltwise(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Eltwise(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int GetMode() const;
|
||||
void SetMode(int mode);
|
||||
};
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class Elu : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Elu() = default;
|
||||
explicit Elu(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Elu(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
float GetAlpha() const;
|
||||
void SetAlpha(float alpha);
|
||||
};
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class EmbeddingLookup : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
EmbeddingLookup() = default;
|
||||
explicit EmbeddingLookup(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit EmbeddingLookup(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
float GetMaxNorm() const;
|
||||
void SetMaxNorm(float max_norm);
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class EmbeddingLookupSparse : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
EmbeddingLookupSparse() = default;
|
||||
explicit EmbeddingLookupSparse(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit EmbeddingLookupSparse(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
std::vector<int> GetSpIds() const;
|
||||
std::vector<float> GetSpWeights() const;
|
||||
float GetMaxNortm() const;
|
||||
|
|
|
@ -28,9 +28,11 @@ namespace lite {
|
|||
class Equal : public Arithmetic {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Equal() = default;
|
||||
explicit Equal(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Equal(schema::Primitive *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,9 +28,11 @@ namespace lite {
|
|||
class Exp : public ArithmeticSelf {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Exp() = default;
|
||||
explicit Exp(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Exp(schema::Primitive *primitive) : ArithmeticSelf(primitive) {}
|
||||
#endif
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class ExpandDims : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
ExpandDims() = default;
|
||||
explicit ExpandDims(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit ExpandDims(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetDim() const;
|
||||
void SetDim(int dim);
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class FakeQuantWithMinMaxVars : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
FakeQuantWithMinMaxVars() = default;
|
||||
explicit FakeQuantWithMinMaxVars(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit FakeQuantWithMinMaxVars(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
bool GetNarrowRange() const;
|
||||
int GetNumBits() const;
|
||||
void SetNarrowRange(bool narrow_range);
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class Fill : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Fill() = default;
|
||||
explicit Fill(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Fill(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
std::vector<int> GetDims() const;
|
||||
void SetDims(const std::vector<int> &dims);
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "src/ops/flatten.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -48,5 +49,14 @@ int Flatten::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
|
|||
output->set_shape(output_shape);
|
||||
return RET_OK;
|
||||
}
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int Flatten::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
this->primitive_ = new (schema::PrimitiveT);
|
||||
auto attr = std::make_unique<schema::FlattenT>();
|
||||
this->primitive_->value.type = schema::PrimitiveType_Flatten;
|
||||
this->primitive_->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,11 +28,14 @@ namespace lite {
|
|||
class Flatten : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Flatten() = default;
|
||||
explicit Flatten(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Flatten(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,9 +28,11 @@ namespace lite {
|
|||
class Floor : public ArithmeticSelf {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Floor() = default;
|
||||
explicit Floor(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Floor(schema::Primitive *primitive) : ArithmeticSelf(primitive) {}
|
||||
#endif
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,9 +28,11 @@ namespace lite {
|
|||
class FloorDiv : public Arithmetic {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
FloorDiv() = default;
|
||||
explicit FloorDiv(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit FloorDiv(schema::Primitive *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,9 +28,11 @@ namespace lite {
|
|||
class FloorMod : public Arithmetic {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
FloorMod() = default;
|
||||
explicit FloorMod(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit FloorMod(schema::Primitive *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class FullConnection : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
FullConnection() = default;
|
||||
explicit FullConnection(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit FullConnection(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
bool GetHasBias() const;
|
||||
int GetAxis() const;
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class FusedBatchNorm : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
FusedBatchNorm() = default;
|
||||
explicit FusedBatchNorm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit FusedBatchNorm(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
float GetEpsilon() const;
|
||||
float GetMomentum() const;
|
||||
int GetSpatial() const;
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class Gather : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Gather() = default;
|
||||
explicit Gather(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Gather(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetAxis() const;
|
||||
int GetBatchDims() const;
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class GatherNd : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
GatherNd() = default;
|
||||
explicit GatherNd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit GatherNd(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetBatchDims() const;
|
||||
void SetBatchDims(int batch_dims);
|
||||
|
|
|
@ -27,9 +27,11 @@ namespace lite {
|
|||
class Greater : public Arithmetic {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Greater() = default;
|
||||
explicit Greater(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Greater(schema::Primitive *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,9 +28,11 @@ namespace lite {
|
|||
class GreaterEqual : public Arithmetic {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
GreaterEqual() = default;
|
||||
explicit GreaterEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit GreaterEqual(schema::Primitive *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class L2Norm : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
L2Norm() = default;
|
||||
explicit L2Norm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit L2Norm(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
std::vector<int> GetAxis() const;
|
||||
float GetEpsilon() const;
|
||||
void SetAxis(const std::vector<int> &axis);
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class LeakyReLU : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
LeakyReLU() = default;
|
||||
explicit LeakyReLU(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit LeakyReLU(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
float GetNegativeSlope() const;
|
||||
void SetNegativeSlope(float negative_slope);
|
||||
};
|
||||
|
|
|
@ -28,9 +28,11 @@ namespace lite {
|
|||
class Less : public Arithmetic {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Less() = default;
|
||||
explicit Less(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Less(schema::Primitive *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,9 +28,11 @@ namespace lite {
|
|||
class LessEqual : public Arithmetic {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
LessEqual() = default;
|
||||
explicit LessEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit LessEqual(schema::Primitive *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class LocalResponseNormalization : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
LocalResponseNormalization() = default;
|
||||
explicit LocalResponseNormalization(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit LocalResponseNormalization(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int GetDepthRadius() const;
|
||||
float GetBias() const;
|
||||
float GetAlpha() const;
|
||||
|
|
|
@ -28,9 +28,11 @@ namespace lite {
|
|||
class Log : public ArithmeticSelf {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Log() = default;
|
||||
explicit Log(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Log(schema::Primitive *primitive) : ArithmeticSelf(primitive) {}
|
||||
#endif
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,9 +28,11 @@ namespace lite {
|
|||
class LogicalAnd : public Arithmetic {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
LogicalAnd() = default;
|
||||
explicit LogicalAnd(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit LogicalAnd(schema::Primitive *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,9 +28,11 @@ namespace lite {
|
|||
class LogicalNot : public ArithmeticSelf {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
LogicalNot() = default;
|
||||
explicit LogicalNot(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit LogicalNot(schema::Primitive *primitive) : ArithmeticSelf(primitive) {}
|
||||
#endif
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,9 +28,11 @@ namespace lite {
|
|||
class LogicalOr : public Arithmetic {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
LogicalOr() = default;
|
||||
explicit LogicalOr(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit LogicalOr(schema::Primitive *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class Lrn : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Lrn() = default;
|
||||
explicit Lrn(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Lrn(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
float GetAlpha() const;
|
||||
float GetBeta() const;
|
||||
float GetBias() const;
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class Lstm : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Lstm() = default;
|
||||
explicit Lstm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Lstm(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
bool GetBidirection() const;
|
||||
void SetBidirection(bool bidirection);
|
||||
|
|
|
@ -13,18 +13,21 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_ANF_QUANT_PARSER_H
|
||||
#define MINDSPORE_ANF_QUANT_PARSER_H
|
||||
#include "tools/anf_importer/anf_populater/anf_node_populater.h"
|
||||
#include <vector>
|
||||
namespace mindspore::lite {
|
||||
class AnfQuantPopulater : public AnfNodePopulater {
|
||||
public:
|
||||
AnfQuantPopulater() = default;
|
||||
~AnfQuantPopulater() override = default;
|
||||
int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr,
|
||||
const std::vector<AnfNodePtr> &inputs) override;
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_ANF_QUANT_PARSER_H
|
||||
#include "src/ops/make_tuple.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int MakeTuple::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
this->primitive_ = new (schema::PrimitiveT);
|
||||
auto attr = std::make_unique<schema::MakeTupleT>();
|
||||
this->primitive_->value.type = schema::PrimitiveType_MakeTuple;
|
||||
this->primitive_->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_MINDSPORE_LITE_SRC_OPS_MAKE_TUPLE_H_
|
||||
#define LITE_MINDSPORE_LITE_SRC_OPS_MAKE_TUPLE_H_
|
||||
#include <vector>
|
||||
#include "src/ops/primitive_c.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class MakeTuple : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MakeTuple() = default;
|
||||
explicit MakeTuple(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
#else
|
||||
explicit MakeTuple(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_SRC_OPS_MAKE_TUPLE_H_
|
|
@ -15,7 +15,11 @@
|
|||
*/
|
||||
|
||||
#include "src/ops/matmul.h"
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -26,6 +30,102 @@ bool MatMul::GetTransposeB() const { return this->primitive_->value.AsMatMul()->
|
|||
void MatMul::SetTransposeA(bool transpose_a) { this->primitive_->value.AsMatMul()->transposeA = transpose_a; }
|
||||
void MatMul::SetTransposeB(bool transpose_b) { this->primitive_->value.AsMatMul()->transposeB = transpose_b; }
|
||||
|
||||
void MatMul::CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax) {
|
||||
constexpr float qmin = 0;
|
||||
constexpr float qmax = 255;
|
||||
*mMin = static_cast<float>((qmin - mean) / stdDev);
|
||||
*mMax = static_cast<float>((qmax - mean) / stdDev);
|
||||
}
|
||||
|
||||
void MatMul::PopulaterQuantParam(const Primitive &prim,
|
||||
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
|
||||
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam) {
|
||||
auto narrow_range = prim.GetAttr("narrow_range");
|
||||
bool narrowRangeQuantParam = GetValue<bool>(narrow_range);
|
||||
auto num_bits = prim.GetAttr("num_bits");
|
||||
int32_t numbitsRangeQuantParam = GetValue<int32_t>(num_bits);
|
||||
|
||||
std::vector<schema::QuantParamT> quants;
|
||||
schema::QuantParamT quantParam;
|
||||
auto mean = prim.GetAttr("mean");
|
||||
auto std_dev = prim.GetAttr("std_dev");
|
||||
if (mean != nullptr && std_dev != nullptr) {
|
||||
auto meanQuantOaram = GetValue<double>(mean);
|
||||
double stddevQuantOaram = GetValue<double>(std_dev);
|
||||
float mMin = 0.0;
|
||||
float mMax = 0.0;
|
||||
CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax);
|
||||
quantParam.min = mMin;
|
||||
quantParam.max = mMax;
|
||||
} else {
|
||||
auto inputMin = prim.GetAttr("input_minq");
|
||||
auto inputMax = prim.GetAttr("input_maxq");
|
||||
auto inputMinPtr = inputMin->cast<lite::tensor::TensorPtr>();
|
||||
auto inputMaxPtr = inputMax->cast<lite::tensor::TensorPtr>();
|
||||
float *minBuf = static_cast<float *>(inputMinPtr->Data());
|
||||
float *maxBuf = static_cast<float *>(inputMaxPtr->Data());
|
||||
quantParam.min = *minBuf;
|
||||
quantParam.max = *maxBuf;
|
||||
}
|
||||
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
|
||||
numbitsRangeQuantParam);
|
||||
quants.emplace_back(quantParam);
|
||||
vecInputQuantParam->emplace_back(quants);
|
||||
|
||||
quants.clear();
|
||||
auto filterMin = prim.GetAttr("filter_minq");
|
||||
auto filterMax = prim.GetAttr("filter_maxq");
|
||||
if (filterMin != nullptr && filterMax != nullptr) {
|
||||
auto filterMinPtr = filterMin->cast<lite::tensor::TensorPtr>();
|
||||
auto filterMaxPtr = filterMax->cast<lite::tensor::TensorPtr>();
|
||||
float *minBuf = static_cast<float *>(filterMinPtr->Data());
|
||||
float *maxBuf = static_cast<float *>(filterMaxPtr->Data());
|
||||
for (int i = 0; i < filterMinPtr->DataSize(); ++i) {
|
||||
quantParam.min = *(minBuf++);
|
||||
quantParam.max = *(maxBuf++);
|
||||
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
|
||||
numbitsRangeQuantParam);
|
||||
quants.emplace_back(quantParam);
|
||||
}
|
||||
vecInputQuantParam->emplace_back(quants);
|
||||
}
|
||||
|
||||
quants.clear();
|
||||
auto outputMin = prim.GetAttr("output_minq");
|
||||
auto outputMax = prim.GetAttr("output_maxq");
|
||||
if (outputMin != nullptr && outputMax != nullptr) {
|
||||
auto outputMinPtr = outputMin->cast<lite::tensor::TensorPtr>();
|
||||
auto outputMaxPtr = outputMax->cast<lite::tensor::TensorPtr>();
|
||||
float *minBuf = static_cast<float *>(outputMinPtr->Data());
|
||||
float *maxBuf = static_cast<float *>(outputMaxPtr->Data());
|
||||
quantParam.min = *minBuf;
|
||||
quantParam.max = *maxBuf;
|
||||
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
|
||||
numbitsRangeQuantParam);
|
||||
quants.emplace_back(quantParam);
|
||||
vecOutputQuantParam->emplace_back(quants);
|
||||
}
|
||||
}
|
||||
|
||||
int MatMul::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
this->primitive_ = new (schema::PrimitiveT);
|
||||
|
||||
auto attr = std::make_unique<schema::MatMulT>();
|
||||
attr->transposeA = GetValue<bool>(prim.GetAttr("transpose_a"));
|
||||
attr->transposeB = GetValue<bool>(prim.GetAttr("transpose_b"));
|
||||
|
||||
this->primitive_->value.type = schema::PrimitiveType_MatMul;
|
||||
this->primitive_->value.value = attr.release();
|
||||
if (GetQuantType() == schema::QuantType_AwareTraining) {
|
||||
std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
|
||||
std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
|
||||
PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam);
|
||||
SetInputQuantParam(vecInputQuantParam);
|
||||
SetOutputQuantParam(vecOutputQuantParam);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
bool MatMul::GetTransposeA() const { return this->primitive_->value_as_MatMul()->transposeA(); }
|
||||
|
|
|
@ -20,18 +20,29 @@
|
|||
#include <vector>
|
||||
#include <set>
|
||||
#include <cmath>
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
#include "ir/dtype/type_id.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class MatMul : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
public:
|
||||
MatMul() = default;
|
||||
explicit MatMul(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
explicit MatMul(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
|
||||
private:
|
||||
void PopulaterQuantParam(const Primitive &prim, std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
|
||||
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam);
|
||||
void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax);
|
||||
#else
|
||||
|
||||
public:
|
||||
explicit MatMul(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
|
||||
public:
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
bool GetTransposeA() const;
|
||||
bool GetTransposeB() const;
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class MatrixDiag : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MatrixDiag() = default;
|
||||
explicit MatrixDiag(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit MatrixDiag(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int GetK() const;
|
||||
int GetNumRows() const;
|
||||
int GetNumCols() const;
|
||||
|
|
|
@ -28,9 +28,11 @@ namespace lite {
|
|||
class Maximum : public Arithmetic {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Maximum() = default;
|
||||
explicit Maximum(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Maximum(schema::Primitive *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class Mean : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Mean() = default;
|
||||
explicit Mean(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Mean(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
std::vector<int> GetAxis() const;
|
||||
bool GetKeepDims() const;
|
||||
|
|
|
@ -28,9 +28,11 @@ namespace lite {
|
|||
class Minimum : public Arithmetic {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Minimum() = default;
|
||||
explicit Minimum(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Minimum(schema::Primitive *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "src/ops/mul.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -24,6 +25,14 @@ int Mul::GetActivationType() const { return this->primitive_->value.AsMul()->act
|
|||
void Mul::SetActivationType(int activation_type) {
|
||||
this->primitive_->value.AsMul()->activationType = (schema::ActivationType)activation_type;
|
||||
}
|
||||
int Mul::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
this->primitive_ = new (schema::PrimitiveT);
|
||||
auto attr = std::make_unique<schema::MulT>();
|
||||
this->primitive_->value.type = schema::PrimitiveType_Mul;
|
||||
this->primitive_->value.value = attr.release();
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
|
|
|
@ -28,12 +28,15 @@ namespace lite {
|
|||
class Mul : public Arithmetic {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Mul() = default;
|
||||
explicit Mul(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Mul(schema::Primitive *primitive) : Arithmetic(primitive) {}
|
||||
|
||||
#endif
|
||||
int GetActivationType() const;
|
||||
void SetActivationType(int activation_type);
|
||||
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class Nchw2Nhwc : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Nchw2Nhwc() = default;
|
||||
explicit Nchw2Nhwc(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Nchw2Nhwc(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
};
|
||||
} // namespace lite
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class Nhwc2Nchw : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Nhwc2Nchw() = default;
|
||||
explicit Nhwc2Nchw(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Nhwc2Nchw(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
};
|
||||
} // namespace lite
|
||||
|
|
|
@ -28,9 +28,11 @@ namespace lite {
|
|||
class NotEqual : public Arithmetic {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
NotEqual() = default;
|
||||
explicit NotEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit NotEqual(schema::Primitive *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class OneHot : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
OneHot() = default;
|
||||
explicit OneHot(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit OneHot(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetAxis() const;
|
||||
void SetAxis(int axis);
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class Pad : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Pad() = default;
|
||||
explicit Pad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Pad(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
std::vector<int> GetPaddings() const;
|
||||
int GetPaddingMode() const;
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class Permute : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Permute() = default;
|
||||
explicit Permute(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Permute(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
std::vector<long> GetOrder() const;
|
||||
void SetOrder(const std::vector<long> &order);
|
||||
};
|
||||
|
|
|
@ -15,6 +15,9 @@
|
|||
*/
|
||||
|
||||
#include "src/ops/pooling.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -52,6 +55,47 @@ void Pooling::SetRoundMode(int round_mode) {
|
|||
this->primitive_->value.AsPooling()->roundMode = (schema::RoundMode)round_mode;
|
||||
}
|
||||
|
||||
int Pooling::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
this->primitive_ = new (schema::PrimitiveT);
|
||||
auto attr = std::make_unique<schema::PoolingT>();
|
||||
if (prim.instance_name() == "MaxPool") {
|
||||
attr->poolingMode = schema::PoolMode_MAX_POOLING;
|
||||
} else if (prim.instance_name() == "MeanPool") {
|
||||
attr->poolingMode = schema::PoolMode_MEAN_POOLING;
|
||||
}
|
||||
|
||||
auto format = GetValue<std::string>(prim.GetAttr("data_format"));
|
||||
if (format == "NCHW") {
|
||||
attr->format = schema::Format_NCHW;
|
||||
} else if (format == "NHWC") {
|
||||
attr->format = schema::Format_NHWC;
|
||||
} else {
|
||||
attr->format = schema::Format_NUM_OF_FORMAT;
|
||||
}
|
||||
|
||||
auto pad_mode = GetValue<std::string>(prim.GetAttr("padding"));
|
||||
if (pad_mode == "VALID") {
|
||||
attr->padMode = schema::PadMode_VALID;
|
||||
} else if (pad_mode == "SAME") {
|
||||
attr->padMode = schema::PadMode_SAME;
|
||||
} else {
|
||||
attr->padMode = schema::PadMode_NOTSET;
|
||||
}
|
||||
|
||||
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("ksize"));
|
||||
attr->windowH = kernel_size[2];
|
||||
attr->windowW = kernel_size[3];
|
||||
|
||||
auto stride = GetValue<std::vector<int>>(prim.GetAttr("strides"));
|
||||
attr->strideH = stride[2];
|
||||
attr->strideW = stride[3];
|
||||
|
||||
this->primitive_->value.type = schema::PrimitiveType_Pooling;
|
||||
this->primitive_->value.value = attr.release();
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
int Pooling::GetFormat() const { return this->primitive_->value_as_Pooling()->format(); }
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class Pooling : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Pooling() = default;
|
||||
explicit Pooling(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Pooling(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetFormat() const;
|
||||
int GetPoolingMode() const;
|
||||
|
@ -65,6 +66,8 @@ class Pooling : public PrimitiveC {
|
|||
int PadLeft() const;
|
||||
int PadRight() const;
|
||||
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
|
||||
protected:
|
||||
int pad_u_ = 0;
|
||||
int pad_d_ = 0;
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class PoolingGrad : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
PoolingGrad() = default;
|
||||
explicit PoolingGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit PoolingGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int GetFormat() const;
|
||||
int GetPoolingMode() const;
|
||||
bool GetGlobal() const;
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class Power : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
Power() = default;
|
||||
explicit Power(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit Power(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
float GetPower() const;
|
||||
float GetScale() const;
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace lite {
|
|||
class PowerGrad : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
PowerGrad() = default;
|
||||
explicit PowerGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
#else
|
||||
explicit PowerGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
#endif
|
||||
float GetPower() const;
|
||||
float GetScale() const;
|
||||
float GetShift() const;
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue