forked from mindspore-Ecosystem/mindspore
!9837 [MSLITE][Develop] fix code review
From: @sunsuodong Reviewed-by: @zhanghaibo5,@zhang_xue_tong Signed-off-by: @zhang_xue_tong
This commit is contained in:
commit
920d34db0b
|
@ -781,7 +781,7 @@ table Reduce {
|
|||
|
||||
table Transpose {
|
||||
perm: [int];
|
||||
conjugate: bool = false;
|
||||
conjugate: bool = false; // DEPRECATED
|
||||
}
|
||||
|
||||
table Squeeze {
|
||||
|
|
|
@ -39,7 +39,6 @@ OpParameter *PopulateTransposeParameter(const mindspore::lite::PrimitiveC *primi
|
|||
transpose_param->perm_[i++] = *iter;
|
||||
}
|
||||
transpose_param->num_axes_ = i;
|
||||
transpose_param->conjugate_ = param->GetConjugate();
|
||||
return reinterpret_cast<OpParameter *>(transpose_param);
|
||||
}
|
||||
|
||||
|
|
|
@ -27,10 +27,7 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
std::vector<int> Transpose::GetPerm() const { return this->primitive_->value.AsTranspose()->perm; }
|
||||
bool Transpose::GetConjugate() const { return this->primitive_->value.AsTranspose()->conjugate; }
|
||||
|
||||
void Transpose::SetPerm(const std::vector<int> &perm) { this->primitive_->value.AsTranspose()->perm = perm; }
|
||||
void Transpose::SetConjugate(bool conjugate) { this->primitive_->value.AsTranspose()->conjugate = conjugate; }
|
||||
|
||||
int Transpose::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
if (this->primitive_ == nullptr) {
|
||||
|
@ -83,7 +80,6 @@ std::vector<int> Transpose::GetPerm() const {
|
|||
auto fb_vector = this->primitive_->value_as_Transpose()->perm();
|
||||
return std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||
}
|
||||
bool Transpose::GetConjugate() const { return this->primitive_->value_as_Transpose()->conjugate(); }
|
||||
|
||||
int Transpose::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
|
@ -127,11 +123,6 @@ int Transpose::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> o
|
|||
MS_ASSERT(inputs_.size() == kSingleNum || inputs_.size() == kDoubleNum);
|
||||
MS_ASSERT(outputs_.size() == kSingleNum);
|
||||
|
||||
int conjugate = GetConjugate();
|
||||
if (conjugate) {
|
||||
MS_LOG(ERROR) << "Transpose conjugate is not support currently";
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::vector<int> perm;
|
||||
for (size_t i = 0; i < GetPerm().size(); i++) {
|
||||
perm.push_back(GetPerm().at(i));
|
||||
|
|
|
@ -35,13 +35,11 @@ class Transpose : public PrimitiveC {
|
|||
explicit Transpose(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
||||
void SetPerm(const std::vector<int> &perm);
|
||||
void SetConjugate(bool conjugate);
|
||||
#else
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||
std::vector<int> GetPerm() const;
|
||||
bool GetConjugate() const;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,7 +21,6 @@
|
|||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
#include "src/runtime/kernel/arm/fp32/nchw2nhwc_fp32.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
#include "src/runtime/kernel/arm/fp32/nchw2nhwc_fp32.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
#include "src/runtime/kernel/arm/fp32/nchw2nhwc_fp32.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
#include "src/runtime/kernel/arm/fp32/nchw2nhwc_fp32.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
|
|
|
@ -35,7 +35,6 @@ TEST_F(TestTfliteParserTranspose, OpType) {
|
|||
TEST_F(TestTfliteParserTranspose, AttrValue) {
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTranspose(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsTranspose();
|
||||
ASSERT_EQ(val->conjugate, false);
|
||||
std::vector<int32_t> perm = {1, 0};
|
||||
ASSERT_EQ(val->perm, perm);
|
||||
}
|
||||
|
|
|
@ -203,7 +203,7 @@ int Benchmark::ReadTensorData(std::ifstream &in_file_stream, const std::string &
|
|||
}
|
||||
auto *check_tensor = new (std::nothrow) CheckTensor(dims, data, strings_data);
|
||||
if (check_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Now CheckTensor failed, tensor name: " << tensor_name;
|
||||
MS_LOG(ERROR) << "New CheckTensor failed, tensor name: " << tensor_name;
|
||||
return RET_ERROR;
|
||||
}
|
||||
this->benchmark_data_.insert(std::make_pair(tensor_name, check_tensor));
|
||||
|
|
|
@ -193,7 +193,7 @@ class MS_API Benchmark {
|
|||
auto tolerance = absoluteTolerance + relativeTolerance * fabs(calibTensor->data.at(j));
|
||||
auto absoluteError = std::fabs(msTensorData[j] - calibTensor->data.at(j));
|
||||
if (absoluteError > tolerance) {
|
||||
if (fabs(calibTensor->data.at(j)) == 0) {
|
||||
if (fabs(calibTensor->data.at(j) - 0.0f) < FLT_EPSILON) {
|
||||
if (absoluteError > 1e-5) {
|
||||
meanError += absoluteError;
|
||||
errorCount++;
|
||||
|
|
|
@ -155,7 +155,6 @@ STATUS MatMulBiasAddFusionPass::InsertTransposeNode(MetaGraphT *graph, const std
|
|||
MS_LOG(ERROR) << "new transposeParam failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
transposeParam->conjugate = false;
|
||||
transposeParam->perm = {1, 0};
|
||||
transNode->primitive->value.value = transposeParam.release();
|
||||
matmulOpIter =
|
||||
|
|
|
@ -53,13 +53,18 @@ class MatMulBiasAddFusionPass : public FusionPass {
|
|||
size_t id = 0;
|
||||
|
||||
OpDefCopyer TransposeOpCopyer = [](CNodeT *inOpDef) -> std::unique_ptr<CNodeT> {
|
||||
std::unique_ptr<CNodeT> newOpDef(new (std::nothrow) CNodeT);
|
||||
auto newOpDef = std::make_unique<schema::CNodeT>();
|
||||
if (newOpDef == nullptr) {
|
||||
MS_LOG(ERROR) << "new OpDefT failed";
|
||||
MS_LOG(ERROR) << "new CNodeT failed";
|
||||
return nullptr;
|
||||
}
|
||||
newOpDef->name = inOpDef->name;
|
||||
newOpDef->quantType = inOpDef->quantType;
|
||||
newOpDef->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (newOpDef->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new PrimitiveT failed";
|
||||
return nullptr;
|
||||
}
|
||||
newOpDef->primitive->value.type = schema::PrimitiveType_Transpose;
|
||||
auto transposeParam = new (std::nothrow) TransposeT;
|
||||
if (transposeParam == nullptr) {
|
||||
|
@ -68,7 +73,6 @@ class MatMulBiasAddFusionPass : public FusionPass {
|
|||
}
|
||||
auto inParam = inOpDef->primitive->value.AsTranspose();
|
||||
MS_ASSERT(inParam != nullptr);
|
||||
transposeParam->conjugate = inParam->conjugate;
|
||||
transposeParam->perm.resize(inParam->perm.size());
|
||||
std::transform(inParam->perm.begin(), inParam->perm.end(), transposeParam->perm.begin(),
|
||||
[](const int32_t ele) { return ele; });
|
||||
|
|
|
@ -23,7 +23,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
enum FormatTransNodeType { kNCHW2NHWC, kNHWC2NCHW };
|
||||
enum FormatTransNodeType { kNCHW2NHWC, kNHWC2NCHW, kNONE };
|
||||
|
||||
class FormatTransPass : public GraphPass {
|
||||
public:
|
||||
|
|
|
@ -33,7 +33,6 @@ PrimitiveC *CaffePermuteParser::ParseLitePrimitive(const caffe::LayerParameter &
|
|||
for (int i = 0; i < num_order_dims; ++i) {
|
||||
attr->perm[i] = (int32_t)permuteParam.order()[i];
|
||||
}
|
||||
attr->conjugate = false;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
primitive->value.type = schema::PrimitiveType_Transpose;
|
||||
primitive->value.value = attr.release();
|
||||
|
|
|
@ -28,7 +28,6 @@ lite::PrimitiveC *OnnxTransposeParser::ParseLitePrimitive(const onnx::GraphProto
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
attr->conjugate = false;
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto &attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "axes" || attribute_name == "perm") {
|
||||
|
|
|
@ -42,7 +42,6 @@ STATUS TFTransposeParser::Parse(const tensorflow::NodeDef &tf_op,
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
attr->conjugate = false;
|
||||
if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) {
|
||||
MS_LOG(ERROR) << "Find Transpose input perm failed";
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -40,7 +40,6 @@ PrimitiveC *TfliteTransposeParser::ParseLitePrimitive(const std::unique_ptr<tfli
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
attr->conjugate = false;
|
||||
primitive->value.type = schema::PrimitiveType_Transpose;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
|
|
|
@ -210,7 +210,7 @@ STATUS getPaddingParam(const std::unique_ptr<tflite::TensorT> &tensor, schema::P
|
|||
return RET_ERROR;
|
||||
}
|
||||
if (tensor->shape.empty()) {
|
||||
MS_LOG(DEBUG) << "the tensor's shape is dynamic, which obtain nly when running.";
|
||||
MS_LOG(DEBUG) << "the tensor's shape is dynamic, which obtain only when running.";
|
||||
return RET_NO_CHANGE;
|
||||
}
|
||||
int padUp = 0;
|
||||
|
|
Loading…
Reference in New Issue