fix protobuf cve and clean code

This commit is contained in:
changzherui 2022-03-21 10:03:38 +08:00
parent 1dba54470c
commit a584ac5d56
9 changed files with 194 additions and 16 deletions

View File

@ -44,6 +44,12 @@ else()
set(MD5 "1a6274bc4a65b55a6fa70e264d796490")
endif()
if(BUILD_LITE)
set(PROTOBUF_PATCH_ROOT ${TOP_DIR}/third_party/patch/protobuf)
else()
set(PROTOBUF_PATCH_ROOT ${CMAKE_SOURCE_DIR}/third_party/patch/protobuf)
endif()
mindspore_add_pkg(protobuf
VER 3.13.0
LIBS protobuf
@ -51,7 +57,8 @@ mindspore_add_pkg(protobuf
URL ${REQ_URL}
MD5 ${MD5}
CMAKE_PATH cmake/
CMAKE_OPTION -Dprotobuf_BUILD_TESTS=OFF -Dprotobuf_BUILD_SHARED_LIBS=OFF -DCMAKE_BUILD_TYPE=Release)
CMAKE_OPTION -Dprotobuf_BUILD_TESTS=OFF -Dprotobuf_BUILD_SHARED_LIBS=OFF -DCMAKE_BUILD_TYPE=Release
PATCHES ${PROTOBUF_PATCH_ROOT}/CVE-2021-22570.patch)
include_directories(${protobuf_INC})
add_library(mindspore::protobuf ALIAS protobuf::protobuf)

View File

@ -44,6 +44,12 @@ else()
set(MD5 "1a6274bc4a65b55a6fa70e264d796490")
endif()
if(BUILD_LITE)
set(PROTOBUF_PATCH_ROOT ${TOP_DIR}/third_party/patch/protobuf)
else()
set(PROTOBUF_PATCH_ROOT ${CMAKE_SOURCE_DIR}/third_party/patch/protobuf)
endif()
mindspore_add_pkg(protobuf_arm
VER 3.13.0
LIBS protobuf
@ -56,7 +62,8 @@ mindspore_add_pkg(protobuf_arm
-DCMAKE_BUILD_TYPE=Release
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-Dprotobuf_WITH_ZLIB=OFF)
-Dprotobuf_WITH_ZLIB=OFF
PATCHES ${PROTOBUF_PATCH_ROOT}/CVE-2021-22570.patch)
include_directories(${protobuf_arm_INC})
add_library(mindspore::protobuf_arm ALIAS protobuf_arm::protobuf)

View File

@ -173,8 +173,8 @@ class COMMON_EXPORT DfGraphConvertor {
void SetNodeInput(AnfNodePtr node);
void SetOpControlInput(const AnfNodePtr &node);
void UpdateOpDesc(AnfNodePtr node);
void SetSubgraph(AnfNodePtr node);
void ProcessSubgraph(AnfNodePtr node, const std::vector<AnfNodePtr> &inputs);
void SetSubgraph(const AnfNodePtr &node);
void ProcessSubgraph(const AnfNodePtr &node, const std::vector<AnfNodePtr> &inputs);
void BuildSaveCheckpointGraph();
void DrawCNode(const CNodePtr node, const OpAdapterPtr adpt);
void UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const;

View File

@ -398,7 +398,7 @@ void DfGraphConvertor::InitParamWithData(const TensorOrderMap &tensors) {
// convert all parameter need initialize to variable
DfGraphConvertor &DfGraphConvertor::InitParam(const TensorOrderMap &tensors) {
size_t input_idx = 0;
if (error_ != 0) {
if (error_ != SUCCESS) {
return *this;
}
if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) {
@ -481,7 +481,7 @@ void DfGraphConvertor::BuildSaveCheckpointGraph() {
#endif
DfGraphConvertor &DfGraphConvertor::GenerateBroadcastGraph(const TensorOrderMap &tensors) {
if (error_ != 0) {
if (error_ != SUCCESS) {
return *this;
}
if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) {
@ -538,7 +538,7 @@ DfGraphConvertor &DfGraphConvertor::GenerateBroadcastGraph(const TensorOrderMap
}
DfGraphConvertor &DfGraphConvertor::GenerateCheckpointGraph() {
if (error_ != 0) {
if (error_ != SUCCESS) {
MS_LOG(ERROR) << "Generate checkpoint graph failed, found error code " << error_ << ".";
return *this;
}
@ -559,7 +559,7 @@ DfGraphConvertor &DfGraphConvertor::GenerateCheckpointGraph() {
}
DfGraphConvertor &DfGraphConvertor::ConvertAllNode() {
if (error_ != 0) {
if (error_ != SUCCESS) {
return *this;
}
if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) {
@ -588,7 +588,7 @@ DfGraphConvertor &DfGraphConvertor::ConvertAllNode() {
std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_);
for (auto &it : nodes) {
(void)Convert(it);
if (this->error_ != 0) {
if (this->error_ != SUCCESS) {
MS_LOG(ERROR) << "failed to convert node: " << it->DebugString() << ".";
}
}
@ -737,7 +737,7 @@ void SetupDatasetIterGetNextNode(const OperatorPtr &op) {
return;
}
void DfGraphConvertor::SetSubgraph(AnfNodePtr node) {
void DfGraphConvertor::SetSubgraph(const AnfNodePtr &node) {
if (!node->isa<CNode>()) {
return;
}
@ -850,7 +850,7 @@ void DfGraphConvertor::UpdateTupleOutCache() {
DfGraphConvertor &DfGraphConvertor::BuildGraph() {
SetupDatasetIterGetNextNode(dataset_iter_getnext_);
if (error_ != 0) {
if (error_ != SUCCESS) {
return *this;
}
@ -877,7 +877,7 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() {
UpdateOpDesc(it);
}
if (error_ == 0) {
if (error_ == SUCCESS) {
df_graph_ = make_shared<DfGraph>(anf_graph_->ToString());
} else {
return *this;
@ -1416,7 +1416,7 @@ void DfGraphConvertor::SetNodeInput(const AnfNodePtr node) {
DfGraphConvertor::SetOpInput(adpt, cnode);
}
void DfGraphConvertor::ProcessSubgraph(AnfNodePtr node, const std::vector<AnfNodePtr> &inputs) {
void DfGraphConvertor::ProcessSubgraph(const AnfNodePtr &node, const std::vector<AnfNodePtr> &inputs) {
if (!node->isa<CNode>() || GetCNodeFuncName(node->cast<CNodePtr>()) != "Partial") {
return;
}

View File

@ -194,7 +194,7 @@ class OpAdapter : public BaseOpAdapter {
return impl_->SetOpSubgraphFunc(op, index, branches);
}
int setSubgraph(const OperatorPtr &op, int index, std::shared_ptr<std::vector<DfGraph>> branches) override {
int setSubgraph(const OperatorPtr &op, int index, const std::shared_ptr<std::vector<DfGraph>> &branches) override {
return static_cast<int>(SetOpSubgraphFunc(op, index, branches));
}

View File

@ -106,7 +106,7 @@ class BaseOpAdapter {
virtual ~BaseOpAdapter() {}
virtual OperatorPtr generate(const AnfNodePtr &anf) = 0;
virtual OperatorPtr generate(const std::string &type) { return std::make_shared<ge::Operator>(type); }
virtual int setSubgraph(const OperatorPtr &op, int index, std::shared_ptr<std::vector<DfGraph>> branches) = 0;
virtual int setSubgraph(const OperatorPtr &op, int index, const std::shared_ptr<std::vector<DfGraph>> &branches) = 0;
virtual int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) = 0;
virtual int setInput(const OperatorPtr &op, int index, const OutHandler &handle) = 0;
virtual int setInput(const OperatorPtr &op, int index,

View File

@ -408,7 +408,7 @@ std::string TransformUtil::PrintGeTensor(const GeTensorPtr ge_tensor) {
}
MS_LOG(INFO) << "Ge Tensor data type is : " << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
switch (ge_tensor->GetTensorDesc().GetDataType()) {
switch (static_cast<int>(ge_tensor->GetTensorDesc().GetDataType())) {
case GeDataType::DT_UINT32:
ret = PrintVector(MakeVector<uint32_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
break;

View File

@ -579,6 +579,7 @@ if(MSLITE_ENABLE_CONVERTER OR MSLITE_ENABLE_RUNTIME_GLOG)
endif()
if(MSLITE_ENABLE_CONVERTER)
find_package(Patch)
include(${TOP_DIR}/cmake/external_libs/opencv.cmake)
include_directories(${PYTHON_INCLUDE_DIRS})
include(${TOP_DIR}/cmake/external_libs/eigen.cmake)

View File

@ -0,0 +1,163 @@
diff --git a/src/google/protobuf/descriptor.cc b/src/google/protobuf/descriptor.cc
index 9a448ffc8..40510b46c 100644
--- a/src/google/protobuf/descriptor.cc
+++ b/src/google/protobuf/descriptor.cc
@@ -1090,7 +1090,7 @@ inline void DescriptorPool::Tables::FindAllExtensions(
bool DescriptorPool::Tables::AddSymbol(const std::string& full_name,
Symbol symbol) {
- if (InsertIfNotPresent(&symbols_by_name_, full_name.c_str(), symbol)) {
+ if (InsertIfNotPresent(&symbols_by_name_, full_name, symbol)) {
symbols_after_checkpoint_.push_back(full_name.c_str());
return true;
} else {
@@ -1106,7 +1106,7 @@ bool FileDescriptorTables::AddAliasUnderParent(const void* parent,
}
bool DescriptorPool::Tables::AddFile(const FileDescriptor* file) {
- if (InsertIfNotPresent(&files_by_name_, file->name().c_str(), file)) {
+ if (InsertIfNotPresent(&files_by_name_, file->name(), file)) {
files_after_checkpoint_.push_back(file->name().c_str());
return true;
} else {
@@ -2628,6 +2628,8 @@ void Descriptor::DebugString(int depth, std::string* contents,
const Descriptor::ReservedRange* range = reserved_range(i);
if (range->end == range->start + 1) {
strings::SubstituteAndAppend(contents, "$0, ", range->start);
+ } else if (range->end > FieldDescriptor::kMaxNumber) {
+ strings::SubstituteAndAppend(contents, "$0 to max, ", range->start);
} else {
strings::SubstituteAndAppend(contents, "$0 to $1, ", range->start,
range->end - 1);
@@ -2831,6 +2833,8 @@ void EnumDescriptor::DebugString(
const EnumDescriptor::ReservedRange* range = reserved_range(i);
if (range->end == range->start) {
strings::SubstituteAndAppend(contents, "$0, ", range->start);
+ } else if (range->end == INT_MAX) {
+ strings::SubstituteAndAppend(contents, "$0 to max, ", range->start);
} else {
strings::SubstituteAndAppend(contents, "$0 to $1, ", range->start,
range->end);
@@ -4022,6 +4026,12 @@ bool DescriptorBuilder::AddSymbol(const std::string& full_name,
// Use its file as the parent instead.
if (parent == nullptr) parent = file_;
+ if (full_name.find('\0') != std::string::npos) {
+ AddError(full_name, proto, DescriptorPool::ErrorCollector::NAME,
+ "\"" + full_name + "\" contains null character.");
+ return false;
+ }
+
if (tables_->AddSymbol(full_name, symbol)) {
if (!file_tables_->AddAliasUnderParent(parent, name, symbol)) {
// This is only possible if there was already an error adding something of
@@ -4061,6 +4071,11 @@ bool DescriptorBuilder::AddSymbol(const std::string& full_name,
void DescriptorBuilder::AddPackage(const std::string& name,
const Message& proto,
const FileDescriptor* file) {
+ if (name.find('\0') != std::string::npos) {
+ AddError(name, proto, DescriptorPool::ErrorCollector::NAME,
+ "\"" + name + "\" contains null character.");
+ return;
+ }
if (tables_->AddSymbol(name, Symbol(file))) {
// Success. Also add parent package, if any.
std::string::size_type dot_pos = name.find_last_of('.');
@@ -4374,6 +4389,12 @@ FileDescriptor* DescriptorBuilder::BuildFileImpl(
}
result->pool_ = pool_;
+ if (result->name().find('\0') != std::string::npos) {
+ AddError(result->name(), proto, DescriptorPool::ErrorCollector::NAME,
+ "\"" + result->name() + "\" contains null character.");
+ return nullptr;
+ }
+
// Add to tables.
if (!tables_->AddFile(result)) {
AddError(proto.name(), proto, DescriptorPool::ErrorCollector::OTHER,
diff --git a/src/google/protobuf/descriptor_unittest.cc b/src/google/protobuf/descriptor_unittest.cc
index 6085a122a..56c180aa4 100644
--- a/src/google/protobuf/descriptor_unittest.cc
+++ b/src/google/protobuf/descriptor_unittest.cc
@@ -3786,6 +3786,45 @@ TEST_F(ValidationErrorTest, InvalidPackageName) {
"foo.proto: foo.$: NAME: \"$\" is not a valid identifier.\n");
}
+// 'str' is a static C-style string that may contain '\0'
+#define STATIC_STR(str) std::string((str), sizeof(str) - 1)
+
+TEST_F(ValidationErrorTest, NullCharSymbolName) {
+ BuildFileWithErrors(
+ "name: \"bar.proto\" "
+ "package: \"foo\""
+ "message_type { "
+ " name: '\\000\\001\\013.Bar' "
+ " field { name: \"foo\" number: 9 label:LABEL_OPTIONAL type:TYPE_INT32 "
+ "} "
+ "}",
+ STATIC_STR("bar.proto: foo.\0\x1\v.Bar: NAME: \"\0\x1\v.Bar\" is not a "
+ "valid identifier.\nbar.proto: foo.\0\x1\v.Bar: NAME: "
+ "\"\0\x1\v.Bar\" is not a valid identifier.\nbar.proto: "
+ "foo.\0\x1\v.Bar: NAME: \"\0\x1\v.Bar\" is not a valid "
+ "identifier.\nbar.proto: foo.\0\x1\v.Bar: NAME: "
+ "\"\0\x1\v.Bar\" is not a valid identifier.\nbar.proto: "
+ "foo.\0\x1\v.Bar.foo: NAME: \"foo.\0\x1\v.Bar.foo\" contains "
+ "null character.\nbar.proto: foo.\0\x1\v.Bar: NAME: "
+ "\"foo.\0\x1\v.Bar\" contains null character.\n"));
+}
+
+TEST_F(ValidationErrorTest, NullCharFileName) {
+ BuildFileWithErrors(
+ "name: \"bar\\000\\001\\013.proto\" "
+ "package: \"outer.foo\"",
+ STATIC_STR("bar\0\x1\v.proto: bar\0\x1\v.proto: NAME: "
+ "\"bar\0\x1\v.proto\" contains null character.\n"));
+}
+
+TEST_F(ValidationErrorTest, NullCharPackageName) {
+ BuildFileWithErrors(
+ "name: \"bar.proto\" "
+ "package: \"\\000\\001\\013.\"",
+ STATIC_STR("bar.proto: \0\x1\v.: NAME: \"\0\x1\v.\" contains null "
+ "character.\n"));
+}
+
TEST_F(ValidationErrorTest, MissingFileName) {
BuildFileWithErrors("",
@@ -4001,6 +4040,32 @@ TEST_F(ValidationErrorTest, ReservedFieldsDebugString) {
file->DebugString());
}
+TEST_F(ValidationErrorTest, DebugStringReservedRangeMax) {
+ const FileDescriptor* file = BuildFile(strings::Substitute(
+ "name: \"foo.proto\" "
+ "enum_type { "
+ " name: \"Bar\""
+ " value { name:\"BAR\" number:1 }"
+ " reserved_range { start: 5 end: $0 }"
+ "}"
+ "message_type {"
+ " name: \"Foo\""
+ " reserved_range { start: 5 end: $1 }"
+ "}",
+ std::numeric_limits<int>::max(), FieldDescriptor::kMaxNumber + 1));
+
+ ASSERT_EQ(
+ "syntax = \"proto2\";\n\n"
+ "enum Bar {\n"
+ " BAR = 1;\n"
+ " reserved 5 to max;\n"
+ "}\n\n"
+ "message Foo {\n"
+ " reserved 5 to max;\n"
+ "}\n\n",
+ file->DebugString());
+}
+
TEST_F(ValidationErrorTest, EnumReservedFieldError) {
BuildFileWithErrors(
"name: \"foo.proto\" "