!9837 [MSLITE][Develop] fix code review

From: @sunsuodong
Reviewed-by: @zhanghaibo5,@zhang_xue_tong
Signed-off-by: @zhang_xue_tong
This commit is contained in:
mindspore-ci-bot 2020-12-14 09:57:25 +08:00 committed by Gitee
commit 920d34db0b
19 changed files with 12 additions and 30 deletions

View File

@@ -781,7 +781,7 @@ table Reduce {
table Transpose {
perm: [int];
conjugate: bool = false;
conjugate: bool = false; // DEPRECATED
}
table Squeeze {

View File

@@ -39,7 +39,6 @@ OpParameter *PopulateTransposeParameter(const mindspore::lite::PrimitiveC *primi
transpose_param->perm_[i++] = *iter;
}
transpose_param->num_axes_ = i;
transpose_param->conjugate_ = param->GetConjugate();
return reinterpret_cast<OpParameter *>(transpose_param);
}

View File

@@ -27,10 +27,7 @@ namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
std::vector<int> Transpose::GetPerm() const { return this->primitive_->value.AsTranspose()->perm; }
bool Transpose::GetConjugate() const { return this->primitive_->value.AsTranspose()->conjugate; }
void Transpose::SetPerm(const std::vector<int> &perm) { this->primitive_->value.AsTranspose()->perm = perm; }
void Transpose::SetConjugate(bool conjugate) { this->primitive_->value.AsTranspose()->conjugate = conjugate; }
int Transpose::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
@@ -83,7 +80,6 @@ std::vector<int> Transpose::GetPerm() const {
auto fb_vector = this->primitive_->value_as_Transpose()->perm();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
bool Transpose::GetConjugate() const { return this->primitive_->value_as_Transpose()->conjugate(); }
int Transpose::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
@@ -127,11 +123,6 @@ int Transpose::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> o
MS_ASSERT(inputs_.size() == kSingleNum || inputs_.size() == kDoubleNum);
MS_ASSERT(outputs_.size() == kSingleNum);
int conjugate = GetConjugate();
if (conjugate) {
MS_LOG(ERROR) << "Transpose conjugate is not support currently";
return RET_ERROR;
}
std::vector<int> perm;
for (size_t i = 0; i < GetPerm().size(); i++) {
perm.push_back(GetPerm().at(i));

View File

@@ -35,13 +35,11 @@ class Transpose : public PrimitiveC {
explicit Transpose(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
void SetPerm(const std::vector<int> &perm);
void SetConjugate(bool conjugate);
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
std::vector<int> GetPerm() const;
bool GetConjugate() const;
};
} // namespace lite
} // namespace mindspore

View File

@@ -21,7 +21,6 @@
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/fp32/nchw2nhwc_fp32.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;

View File

@@ -20,7 +20,6 @@
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/fp32/nchw2nhwc_fp32.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;

View File

@@ -20,7 +20,6 @@
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/fp32/nchw2nhwc_fp32.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;

View File

@@ -20,7 +20,6 @@
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/fp32/nchw2nhwc_fp32.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;

View File

@@ -35,7 +35,6 @@ TEST_F(TestTfliteParserTranspose, OpType) {
TEST_F(TestTfliteParserTranspose, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTranspose(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsTranspose();
ASSERT_EQ(val->conjugate, false);
std::vector<int32_t> perm = {1, 0};
ASSERT_EQ(val->perm, perm);
}

View File

@@ -203,7 +203,7 @@ int Benchmark::ReadTensorData(std::ifstream &in_file_stream, const std::string &
}
auto *check_tensor = new (std::nothrow) CheckTensor(dims, data, strings_data);
if (check_tensor == nullptr) {
MS_LOG(ERROR) << "Now CheckTensor failed, tensor name: " << tensor_name;
MS_LOG(ERROR) << "New CheckTensor failed, tensor name: " << tensor_name;
return RET_ERROR;
}
this->benchmark_data_.insert(std::make_pair(tensor_name, check_tensor));

View File

@@ -193,7 +193,7 @@ class MS_API Benchmark {
auto tolerance = absoluteTolerance + relativeTolerance * fabs(calibTensor->data.at(j));
auto absoluteError = std::fabs(msTensorData[j] - calibTensor->data.at(j));
if (absoluteError > tolerance) {
if (fabs(calibTensor->data.at(j)) == 0) {
if (fabs(calibTensor->data.at(j) - 0.0f) < FLT_EPSILON) {
if (absoluteError > 1e-5) {
meanError += absoluteError;
errorCount++;

View File

@@ -155,7 +155,6 @@ STATUS MatMulBiasAddFusionPass::InsertTransposeNode(MetaGraphT *graph, const std
MS_LOG(ERROR) << "new transposeParam failed";
return RET_ERROR;
}
transposeParam->conjugate = false;
transposeParam->perm = {1, 0};
transNode->primitive->value.value = transposeParam.release();
matmulOpIter =

View File

@@ -53,13 +53,18 @@ class MatMulBiasAddFusionPass : public FusionPass {
size_t id = 0;
OpDefCopyer TransposeOpCopyer = [](CNodeT *inOpDef) -> std::unique_ptr<CNodeT> {
std::unique_ptr<CNodeT> newOpDef(new (std::nothrow) CNodeT);
auto newOpDef = std::make_unique<schema::CNodeT>();
if (newOpDef == nullptr) {
MS_LOG(ERROR) << "new OpDefT failed";
MS_LOG(ERROR) << "new CNodeT failed";
return nullptr;
}
newOpDef->name = inOpDef->name;
newOpDef->quantType = inOpDef->quantType;
newOpDef->primitive = std::make_unique<schema::PrimitiveT>();
if (newOpDef->primitive == nullptr) {
MS_LOG(ERROR) << "new PrimitiveT failed";
return nullptr;
}
newOpDef->primitive->value.type = schema::PrimitiveType_Transpose;
auto transposeParam = new (std::nothrow) TransposeT;
if (transposeParam == nullptr) {
@@ -68,7 +73,6 @@ class MatMulBiasAddFusionPass : public FusionPass {
}
auto inParam = inOpDef->primitive->value.AsTranspose();
MS_ASSERT(inParam != nullptr);
transposeParam->conjugate = inParam->conjugate;
transposeParam->perm.resize(inParam->perm.size());
std::transform(inParam->perm.begin(), inParam->perm.end(), transposeParam->perm.begin(),
[](const int32_t ele) { return ele; });

View File

@@ -23,7 +23,7 @@
namespace mindspore {
namespace lite {
enum FormatTransNodeType { kNCHW2NHWC, kNHWC2NCHW };
enum FormatTransNodeType { kNCHW2NHWC, kNHWC2NCHW, kNONE };
class FormatTransPass : public GraphPass {
public:

View File

@@ -33,7 +33,6 @@ PrimitiveC *CaffePermuteParser::ParseLitePrimitive(const caffe::LayerParameter &
for (int i = 0; i < num_order_dims; ++i) {
attr->perm[i] = (int32_t)permuteParam.order()[i];
}
attr->conjugate = false;
auto primitive = std::make_unique<schema::PrimitiveT>();
primitive->value.type = schema::PrimitiveType_Transpose;
primitive->value.value = attr.release();

View File

@@ -28,7 +28,6 @@ lite::PrimitiveC *OnnxTransposeParser::ParseLitePrimitive(const onnx::GraphProto
return nullptr;
}
attr->conjugate = false;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axes" || attribute_name == "perm") {

View File

@@ -42,7 +42,6 @@ STATUS TFTransposeParser::Parse(const tensorflow::NodeDef &tf_op,
return RET_NULL_PTR;
}
attr->conjugate = false;
if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) {
MS_LOG(ERROR) << "Find Transpose input perm failed";
return RET_ERROR;

View File

@@ -40,7 +40,6 @@ PrimitiveC *TfliteTransposeParser::ParseLitePrimitive(const std::unique_ptr<tfli
return nullptr;
}
attr->conjugate = false;
primitive->value.type = schema::PrimitiveType_Transpose;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());

View File

@@ -210,7 +210,7 @@ STATUS getPaddingParam(const std::unique_ptr<tflite::TensorT> &tensor, schema::P
return RET_ERROR;
}
if (tensor->shape.empty()) {
MS_LOG(DEBUG) << "the tensor's shape is dynamic, which obtain nly when running.";
MS_LOG(DEBUG) << "the tensor's shape is dynamic, which obtain only when running.";
return RET_NO_CHANGE;
}
int padUp = 0;