!5862 [MSLITE] grad ops added
Merge pull request !5862 from wangchangkai/master
Commit: 8b007f24a9
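This PR wires several training (grad) operators into the lite converter: ActivationGrad, BiasGrad, Conv2DGradFilter, Conv2DGradInput, PoolingGrad and PowerGrad gain an UnPackAttr implementation, BNGradInput is added as a new op, and the converter's two dispatch tables are extended to reach them. Every op follows the same UnPackAttr contract: lazily allocate the schema::PrimitiveT, verify its type tag, translate the ANF Primitive attributes into the generated attribute struct, and attach it. The sketch below only distills that shared pattern; SomeGrad, SomeGradT and "some_attr" are placeholder names for illustration, not identifiers from this commit.

// Hedged sketch of the UnPackAttr pattern repeated by the grad ops in this PR.
// "SomeGrad", "SomeGradT" and "some_attr" are placeholders, not names from the commit.
int SomeGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    // Lazily allocate the mutable flatbuffer object and tag it with this op's type.
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_SomeGrad;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_SomeGrad) {
    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
    return RET_ERROR;
  }
  if (this->primitive_->value.value == nullptr) {
    // Copy the ANF attributes into the generated *T struct and hand it to primitive_.
    auto attr = new (std::nothrow) schema::SomeGradT();
    if (attr == nullptr) {
      MS_LOG(ERROR) << "new primitiveT value failed";
      return RET_ERROR;
    }
    attr->some_attr = GetValue<float>(prim.GetAttr("some_attr"));
    this->primitive_->value.value = attr;
  }
  return RET_OK;
}
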
@@ -179,6 +179,7 @@ union PrimitiveType {
    Conv2DGradInput,
    PoolingGrad,
    BNGrad,
    BNGradInput,
    ApplyMomentum,
    BiasGrad,
    SoftmaxCrossEntropy,

@@ -398,7 +398,10 @@ table BNGrad {
    eps : float;
    momentum: float;
}

table BNGradInput {
    eps : float;
    momentum: float;
}
table Scale {
    axis: int;
}

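The BNGradInput table above only carries eps and momentum; flatc generates both the mutable object API (schema::BNGradInputT) used on the PRIMITIVE_WRITEABLE path and the builder/reader API used in UnPackToFlatBuilder further down. The snippet below is an illustrative round trip through that generated API, assuming an in-tree build; the 1e-5f / 0.9f values are placeholders, not defaults from this commit.

// Illustrative only: serialize the new BNGradInput attributes and read them back
// through the flatc-generated accessors, mirroring bn_grad_input.cc below.
flatbuffers::FlatBufferBuilder fbb;
auto val_offset = schema::CreateBNGradInput(fbb, /*eps=*/1e-5f, /*momentum=*/0.9f);
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_BNGradInput, val_offset.o);
fbb.Finish(prim_offset);
auto *prim = flatbuffers::GetRoot<schema::Primitive>(fbb.GetBufferPointer());
float eps = prim->value_as_BNGradInput()->eps();            // 1e-5f
float momentum = prim->value_as_BNGradInput()->momentum();  // 0.9f
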
@@ -25,6 +25,36 @@ void ActivationGrad::SetType(int type) {
  this->primitive_->value.AsActivationGrad()->type = (schema::ActivationType)type;
}
void ActivationGrad::SetAlpha(float alpha) { this->primitive_->value.AsActivationGrad()->alpha = alpha; }
int ActivationGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_ActivationGrad;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_ActivationGrad) {
    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
    return RET_ERROR;
  }
  auto attr = std::make_unique<schema::ActivationGradT>();
  if (prim.name() == "ReLU") {
    attr->type = schema::ActivationType_RELU;
  } else if (prim.name() == "Sigmoid") {
    attr->type = schema::ActivationType_SIGMOID;
  } else if (prim.name() == "ReLU6") {
    attr->type = schema::ActivationType_RELU6;
  }
  auto alpha = GetValue<float>(prim.GetAttr("alpha"));
  attr->alpha = alpha;
  this->primitive_->value.value = attr.release();
  if (this->primitive_->value.value == nullptr) {
    MS_LOG(ERROR) << "new primitiveT value failed";
    return RET_ERROR;
  }
  return RET_OK;
}
#else
int ActivationGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);

@@ -20,6 +20,7 @@
#include <vector>
#include <set>
#include <cmath>
#include <memory>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"

@@ -33,6 +34,7 @@ class ActivationGrad : public PrimitiveC {
  explicit ActivationGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  void SetType(int type);
  void SetAlpha(float alpha);
  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
  ActivationGrad() = default;

@@ -22,7 +22,34 @@ namespace lite {
std::vector<int> BiasGrad::GetAxis() const { return this->primitive_->value.AsBiasGrad()->axis; }

void BiasGrad::SetAxis(const std::vector<int> &axis) { this->primitive_->value.AsBiasGrad()->axis = axis; }

int BiasGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_BiasGrad;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_BiasGrad) {
    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
    return RET_ERROR;
  }
  if (this->primitive_->value.value == nullptr) {
    auto attr = new (std::nothrow) schema::BiasGradT();
    if (attr == nullptr) {
      MS_LOG(ERROR) << "new primitiveT value failed";
      return RET_ERROR;
    }
    attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis"));
    this->primitive_->value.value = attr;
    if (this->primitive_->value.value == nullptr) {
      MS_LOG(ERROR) << "primitive value is nullptr";
      return RET_ERROR;
    }
  }
  return RET_OK;
}
#else
int BiasGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);

@@ -33,7 +33,7 @@ class BiasGrad : public PrimitiveC {
  BiasGrad() = default;
  explicit BiasGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  void SetAxis(const std::vector<int> &axis);
  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
  BiasGrad() = default;

@@ -14,8 +14,8 @@
 * limitations under the License.
 */

#ifndef LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_
#define LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_
#ifndef LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_H_
#define LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_H_

#include <vector>
#include <set>

@@ -0,0 +1,75 @@
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/ops/bn_grad_input.h"

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
float BNGradInput::GetEps() const { return this->primitive_->value.AsBNGradInput()->eps; }
float BNGradInput::GetMomentum() const { return this->primitive_->value.AsBNGradInput()->momentum; }

void BNGradInput::SetEps(float eps) { this->primitive_->value.AsBNGradInput()->eps = eps; }
void BNGradInput::SetMomentum(float momentum) { this->primitive_->value.AsBNGradInput()->momentum = momentum; }
int BNGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_BNGradInput;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_BNGradInput) {
    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
    return RET_ERROR;
  }
  if (this->primitive_->value.value == nullptr) {
    auto attr = new (std::nothrow) schema::BNGradInputT();
    if (attr == nullptr) {
      MS_LOG(ERROR) << "new primitiveT value failed";
      return RET_ERROR;
    }
    attr->eps = GetValue<float>(prim.GetAttr("eps"));
    attr->momentum = GetValue<float>(prim.GetAttr("momentum"));
    this->primitive_->value.value = attr;
    if (this->primitive_->value.value == nullptr) {
      MS_LOG(ERROR) << "primitive value is nullptr";
      return RET_ERROR;
    }
  }
  return RET_OK;
}
#else
int BNGradInput::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  auto attr = primitive->value_as_BNGradInput();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "value_as_BNGradInputInput return nullptr";
    return RET_ERROR;
  }
  auto val_offset = schema::CreateBNGradInput(*fbb, attr->eps(), attr->momentum());
  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BNGradInput, val_offset.o);
  fbb->Finish(prim_offset);
  return RET_OK;
}
float BNGradInput::GetEps() const { return this->primitive_->value_as_BNGradInput()->eps(); }
float BNGradInput::GetMomentum() const { return this->primitive_->value_as_BNGradInput()->momentum(); }

#endif
}  // namespace lite
}  // namespace mindspore

@@ -0,0 +1,47 @@
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_
#define LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
class BNGradInput : public PrimitiveC {
 public:
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(BNGradInput, PrimitiveC);
  BNGradInput() = default;
  explicit BNGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  void SetEps(float eps);
  void SetMomentum(float momentum);
  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
  BNGradInput() = default;
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  float GetEps() const;
  float GetMomentum() const;
};
}  // namespace lite
}  // namespace mindspore

#endif  // LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_

@@ -66,7 +66,133 @@ void Conv2DGradFilter::SetHasBias(bool has_bias) { this->primitive_->value.AsCon
void Conv2DGradFilter::SetActivationType(int activation_type) {
  this->primitive_->value.AsConv2DGradFilter()->activationType = (schema::ActivationType)activation_type;
}
void Conv2DGradFilter::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group,
                                                 const std::vector<AnfNodePtr> &inputs) {
  auto attr = std::make_unique<schema::DepthwiseConv2DT>();
  auto format = GetValue<std::string>(prim.GetAttr("data_format"));
  if (format == "NCHW") {
    attr->format = schema::Format_NCHW;
  } else if (format == "NHWC") {
    attr->format = schema::Format_NHWC;
  } else {
    attr->format = schema::Format_NUM_OF_FORMAT;
  }
  auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
  attr->padUp = pad_list[0];
  attr->padDown = pad_list[1];
  attr->padLeft = pad_list[2];
  attr->padRight = pad_list[3];

  auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
  attr->dilateH = dilation[0];
  attr->dilateW = dilation[1];

  auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
  attr->kernelH = kernel_size[0];
  attr->kernelW = kernel_size[1];

  auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
  attr->strideH = stride[2];
  attr->strideW = stride[3];

  auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
  if (pad_mode == "valid") {
    attr->padMode = schema::PadMode_VALID;
  } else if (pad_mode == "same") {
    attr->padMode = schema::PadMode_SAME;
  } else {
    attr->padMode = schema::PadMode_NOTSET;
  }

  if (prim.GetAttr("activation_name") != nullptr) {
    std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
    attr->activationType = kActivationTypeMap[activate_name];
  } else {
    attr->activationType = schema::ActivationType_NO_ACTIVATION;
  }

  int channel_mutiplier = 1;
  if (prim.GetAttr("channel_mutiplier") != nullptr) {
    channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier"));
  }
  attr->channelMultiplier = channel_mutiplier;

  primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
  primitive->value.value = attr.release();
}

void Conv2DGradFilter::PopulaterConv2DSingleGroup(const Primitive &prim,
                                                  schema::PrimitiveT *primitive, const int &group) {
  auto attr = std::make_unique<schema::Conv2DT>();
  attr->group = group;
  auto format = GetValue<std::string>(prim.GetAttr("data_format"));
  if (format == "NCHW") {
    attr->format = schema::Format_NCHW;
  } else if (format == "NHWC") {
    attr->format = schema::Format_NHWC;
  } else {
    attr->format = schema::Format_NUM_OF_FORMAT;
  }
  auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
  attr->padUp = pad_list[0];
  attr->padDown = pad_list[1];
  attr->padLeft = pad_list[2];
  attr->padRight = pad_list[3];

  auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
  attr->dilateH = dilation[0];
  attr->dilateW = dilation[1];

  auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
  attr->kernelH = kernel_size[0];
  attr->kernelW = kernel_size[1];

  auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
  attr->strideH = stride[2];
  attr->strideW = stride[3];

  attr->channelOut = GetValue<int>(prim.GetAttr("out_channel"));

  auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
  if (pad_mode == "valid") {
    attr->padMode = schema::PadMode_VALID;
  } else if (pad_mode == "same") {
    attr->padMode = schema::PadMode_SAME;
  } else {
    attr->padMode = schema::PadMode_NOTSET;
  }

  if (prim.GetAttr("activation_name") != nullptr) {
    std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
    attr->activationType = kActivationTypeMap[activate_name];
  } else {
    attr->activationType = schema::ActivationType_NO_ACTIVATION;
  }

  primitive->value.type = schema::PrimitiveType_Conv2D;
  primitive->value.value = attr.release();
}
int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_Conv2DGradFilter;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_Conv2DGradFilter) {
    MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
    return RET_ERROR;
  }
  int group = GetValue<int>(prim.GetAttr("group"));
  if (group > 1) {
    PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs);
  } else {
    PopulaterConv2DSingleGroup(prim, this->primitive_, group);
  }
  return RET_OK;
}
#else
int Conv2DGradFilter::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);

@@ -20,6 +20,8 @@
#include <vector>
#include <set>
#include <cmath>
#include <memory>
#include <string>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"

@@ -48,6 +50,10 @@ class Conv2DGradFilter : public PrimitiveC {
  void SetDilateH(int dilate_h);
  void SetHasBias(bool has_bias);
  void SetActivationType(int activation_type);
  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
  void PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group,
                                 const std::vector<AnfNodePtr> &inputs);
  void PopulaterConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group);
#else
  Conv2DGradFilter() = default;

@@ -64,7 +64,133 @@ void Conv2DGradInput::SetHasBias(bool has_bias) { this->primitive_->value.AsConv
void Conv2DGradInput::SetActivationType(int activation_type) {
  this->primitive_->value.AsConv2DGradInput()->activationType = (schema::ActivationType)activation_type;
}
void Conv2DGradInput::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group,
                                                const std::vector<AnfNodePtr> &inputs) {
  auto attr = std::make_unique<schema::DepthwiseConv2DT>();
  auto format = GetValue<std::string>(prim.GetAttr("data_format"));
  if (format == "NCHW") {
    attr->format = schema::Format_NCHW;
  } else if (format == "NHWC") {
    attr->format = schema::Format_NHWC;
  } else {
    attr->format = schema::Format_NUM_OF_FORMAT;
  }
  auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
  attr->padUp = pad_list[0];
  attr->padDown = pad_list[1];
  attr->padLeft = pad_list[2];
  attr->padRight = pad_list[3];

  auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
  attr->dilateH = dilation[0];
  attr->dilateW = dilation[1];

  auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
  attr->kernelH = kernel_size[0];
  attr->kernelW = kernel_size[1];

  auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
  attr->strideH = stride[2];
  attr->strideW = stride[3];

  auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
  if (pad_mode == "valid") {
    attr->padMode = schema::PadMode_VALID;
  } else if (pad_mode == "same") {
    attr->padMode = schema::PadMode_SAME;
  } else {
    attr->padMode = schema::PadMode_NOTSET;
  }

  if (prim.GetAttr("activation_name") != nullptr) {
    std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
    attr->activationType = kActivationTypeMap[activate_name];
  } else {
    attr->activationType = schema::ActivationType_NO_ACTIVATION;
  }

  int channel_mutiplier = 1;
  if (prim.GetAttr("channel_mutiplier") != nullptr) {
    channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier"));
  }
  attr->channelMultiplier = channel_mutiplier;

  primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
  primitive->value.value = attr.release();
}

void Conv2DGradInput::PopulaterConv2DSingleGroup(const Primitive &prim,
                                                 schema::PrimitiveT *primitive, const int &group) {
  auto attr = std::make_unique<schema::Conv2DT>();
  attr->group = group;
  auto format = GetValue<std::string>(prim.GetAttr("data_format"));
  if (format == "NCHW") {
    attr->format = schema::Format_NCHW;
  } else if (format == "NHWC") {
    attr->format = schema::Format_NHWC;
  } else {
    attr->format = schema::Format_NUM_OF_FORMAT;
  }
  auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
  attr->padUp = pad_list[0];
  attr->padDown = pad_list[1];
  attr->padLeft = pad_list[2];
  attr->padRight = pad_list[3];

  auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
  attr->dilateH = dilation[0];
  attr->dilateW = dilation[1];

  auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
  attr->kernelH = kernel_size[0];
  attr->kernelW = kernel_size[1];

  auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
  attr->strideH = stride[2];
  attr->strideW = stride[3];

  attr->channelOut = GetValue<int>(prim.GetAttr("out_channel"));

  auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
  if (pad_mode == "valid") {
    attr->padMode = schema::PadMode_VALID;
  } else if (pad_mode == "same") {
    attr->padMode = schema::PadMode_SAME;
  } else {
    attr->padMode = schema::PadMode_NOTSET;
  }

  if (prim.GetAttr("activation_name") != nullptr) {
    std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
    attr->activationType = kActivationTypeMap[activate_name];
  } else {
    attr->activationType = schema::ActivationType_NO_ACTIVATION;
  }

  primitive->value.type = schema::PrimitiveType_Conv2D;
  primitive->value.value = attr.release();
}
int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_Conv2DGradInput;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_Conv2DGradInput) {
    MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
    return RET_ERROR;
  }
  int group = GetValue<int>(prim.GetAttr("group"));
  if (group > 1) {
    PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs);
  } else {
    PopulaterConv2DSingleGroup(prim, this->primitive_, group);
  }
  return RET_OK;
}
#else
int Conv2DGradInput::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);

@@ -20,6 +20,8 @@
#include <vector>
#include <set>
#include <cmath>
#include <memory>
#include <string>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"

@@ -48,6 +50,10 @@ class Conv2DGradInput : public PrimitiveC {
  void SetDilateH(int dilate_h);
  void SetHasBias(bool has_bias);
  void SetActivationType(int activation_type);
  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
  void PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group,
                                 const std::vector<AnfNodePtr> &inputs);
  void PopulaterConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group);
#else
  Conv2DGradInput() = default;

@@ -52,7 +52,64 @@ void PoolingGrad::SetPadRight(int pad_right) { this->primitive_->value.AsPooling
void PoolingGrad::SetRoundMode(int round_mode) {
  this->primitive_->value.AsPoolingGrad()->roundMode = (schema::RoundMode)round_mode;
}
int PoolingGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_PoolingGrad;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_PoolingGrad) {
    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
    return RET_ERROR;
  }
  if (this->primitive_->value.value == nullptr) {
    auto attr = new (std::nothrow) schema::PoolingGradT();
    if (attr == nullptr) {
      MS_LOG(ERROR) << "new primitiveT value failed";
      return RET_ERROR;
    }

    auto format = GetValue<std::string>(prim.GetAttr("data_format"));
    if (format == "NCHW") {
      attr->format = schema::Format_NCHW;
    } else if (format == "NHWC") {
      attr->format = schema::Format_NHWC;
    } else {
      attr->format = schema::Format_NUM_OF_FORMAT;
    }
    if (prim.instance_name() == "MaxPool") {
      attr->poolingMode = schema::PoolMode_MAX_POOLING;
    } else if (prim.instance_name() == "MeanPool") {
      attr->poolingMode = schema::PoolMode_MEAN_POOLING;
    }

    auto pad_mode = GetValue<std::string>(prim.GetAttr("padding"));
    if (pad_mode == "VALID") {
      attr->padMode = schema::PadMode_VALID;
    } else if (pad_mode == "SAME") {
      attr->padMode = schema::PadMode_SAME;
    } else {
      attr->padMode = schema::PadMode_NOTSET;
    }

    auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("ksize"));
    attr->windowH = kernel_size[2];
    attr->windowW = kernel_size[3];

    auto stride = GetValue<std::vector<int>>(prim.GetAttr("strides"));
    attr->strideH = stride[2];
    attr->strideW = stride[3];
    this->primitive_->value.value = attr;
    if (this->primitive_->value.value == nullptr) {
      MS_LOG(ERROR) << "primitive value is nullptr";
      return RET_ERROR;
    }
  }
  return RET_OK;
}
#else

int PoolingGrad::GetFormat() const { return this->primitive_->value_as_PoolingGrad()->format(); }

@@ -20,6 +20,7 @@
#include <vector>
#include <set>
#include <cmath>
#include <string>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"

@@ -44,6 +45,7 @@ class PoolingGrad : public PrimitiveC {
  void SetPadLeft(int pad_left);
  void SetPadRight(int pad_right);
  void SetRoundMode(int round_mode);
  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
  PoolingGrad() = default;

@@ -26,7 +26,36 @@ float PowerGrad::GetShift() const { return this->primitive_->value.AsPowerGrad()
void PowerGrad::SetPower(float power) { this->primitive_->value.AsPowerGrad()->power = power; }
void PowerGrad::SetScale(float scale) { this->primitive_->value.AsPowerGrad()->scale = scale; }
void PowerGrad::SetShift(float shift) { this->primitive_->value.AsPowerGrad()->shift = shift; }

int PowerGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_PowerGrad;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_PowerGrad) {
    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
    return RET_ERROR;
  }
  if (this->primitive_->value.value == nullptr) {
    auto attr = new (std::nothrow) schema::PowerGradT();
    if (attr == nullptr) {
      MS_LOG(ERROR) << "new primitiveT value failed";
      return RET_ERROR;
    }
    attr->power = GetValue<float>(prim.GetAttr("power"));
    attr->scale = GetValue<float>(prim.GetAttr("scale"));
    attr->shift = GetValue<float>(prim.GetAttr("shift"));
    this->primitive_->value.value = attr;
    if (this->primitive_->value.value == nullptr) {
      MS_LOG(ERROR) << "primitive value is nullptr";
      return RET_ERROR;
    }
  }
  return RET_OK;
}
#else

float PowerGrad::GetPower() const { return this->primitive_->value_as_PowerGrad()->power(); }

@@ -34,6 +34,7 @@ class PowerGrad : public PrimitiveC {
  void SetPower(float power);
  void SetScale(float scale);
  void SetShift(float shift);
  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
  PowerGrad() = default;

@@ -383,6 +383,20 @@ std::shared_ptr<PrimitiveC> PrimitiveC::UnPackFromPrimitive(const Primitive &pri
    return NewPrimitiveC<ApplyMomentum>(prim, inputs, quantType);
  } else if (op_type == "BatchNormGrad") {
    return NewPrimitiveC<BNGrad>(prim, inputs, quantType);
  } else if (op_type == "Conv2DGradInput") {
    return NewPrimitiveC<Conv2DGradInput>(prim, inputs, quantType);
  } else if (op_type == "Conv2DGradFilter") {
    return NewPrimitiveC<Conv2DGradFilter>(prim, inputs, quantType);
  } else if (op_type == "BiasGrad") {
    return NewPrimitiveC<BiasGrad>(prim, inputs, quantType);
  } else if (op_type == "ActivationGrad") {
    return NewPrimitiveC<ActivationGrad>(prim, inputs, quantType);
  } else if (op_type == "PoolingGrad") {
    return NewPrimitiveC<PoolingGrad>(prim, inputs, quantType);
  } else if (op_type == "BNGradInput") {
    return NewPrimitiveC<BNGradInput>(prim, inputs, quantType);
  } else if (op_type == "PowerGrad") {
    return NewPrimitiveC<PowerGrad>(prim, inputs, quantType);
#endif
  } else {
    MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromPrimitive : " << op_type;

@@ -620,6 +634,10 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitiveT(mindspore::schema::PrimitiveT
      return new ArithmeticGrad(primitive);
    case schema::PrimitiveType_DivGrad:
      return new ArithmeticGrad(primitive);
    case schema::PrimitiveType_PowerGrad:
      return new PowerGrad(primitive);
    case schema::PrimitiveType_BNGradInput:
      return new BNGradInput(primitive);
#endif

    default:
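
With the two factory hunks above, the new grad ops are reachable both from an ANF Primitive (keyed by the op-type string) and from a deserialized schema::PrimitiveT (keyed by the PrimitiveType tag). The sketch below illustrates the schema-side path only, assuming UnPackFromSchemaPrimitiveT keeps its existing signature and is called under PRIMITIVE_WRITEABLE; the eps/momentum values are placeholders, not defaults from this commit.

// Hedged sketch (not part of this commit): obtaining a BNGradInput wrapper through
// the schema-side factory extended above.
auto *prim_t = new (std::nothrow) schema::PrimitiveT;
prim_t->value.type = schema::PrimitiveType_BNGradInput;
prim_t->value.value = new (std::nothrow) schema::BNGradInputT;    // eps/momentum left at defaults
PrimitiveC *op = PrimitiveC::UnPackFromSchemaPrimitiveT(prim_t);  // hits the new case above
if (op != nullptr) {
  auto *bn_grad_input = static_cast<BNGradInput *>(op);
  bn_grad_input->SetEps(1e-5f);      // placeholder value
  bn_grad_input->SetMomentum(0.9f);  // placeholder value
}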