clean code

This commit is contained in:
huanghui 2022-08-04 09:49:43 +08:00
parent cf07fbc41b
commit 5fe2ba9d42
7 changed files with 14 additions and 14 deletions

View File

@ -130,7 +130,7 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
void ReplaceGraphInput(const AnfNodePtr &old_parameter, const AnfNodePtr &new_parameter);
std::vector<AnfNodePtr> outputs() const;
CNodePtr NewCNode(std::vector<AnfNodePtr> &&inputs) override;
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs) override;
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs = std::vector<AnfNodePtr>()) override;
CNodePtr NewCNodeWithInfos(const std::vector<AnfNodePtr> &inputs, const CNodePtr &ori_cnode = nullptr);
void CreateKernelInfoFromNewParameter(const CNodePtr &cnode) const;
CNodePtr NewCNode(const CNodePtr &cnode);

View File

@ -1448,7 +1448,6 @@ const TypeId AbstractSparseTensor::GetTensorTypeIdAt(size_t index) const {
if (index >= shape_idx || index < 0) {
MS_LOG(EXCEPTION) << "Index must be in range of [0, " << shape_idx << "), but got " << index << " for "
<< ToString();
return kTypeUnknown;
}
auto abs_tensor = GetAbsPtrAt<abstract::AbstractTensorPtr>(index);
MS_EXCEPTION_IF_NULL(abs_tensor);
@ -1459,7 +1458,6 @@ const TypeId AbstractSparseTensor::GetShapeTypeIdAt(size_t index) const {
if (index >= shape()->size() || index < 0) {
MS_LOG(EXCEPTION) << "Index must be in range of [0, " << shape()->size() << "), but got " << index << " for "
<< ToString();
return kTypeUnknown;
}
return shape()->elements()[index]->BuildType()->type_id();
}

View File

@ -924,9 +924,7 @@ AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const Primitive
ShapeVector lengths_shape = lengths->shape()->shape();
ShapeVector lengths_shape_min = lengths->shape()->min_shape();
ShapeVector lengths_shape_max = lengths->shape()->max_shape();
if (!lengths_shape_max.empty() && !lengths_shape_min.empty()) {
lengths_shape_min.push_back(maxlen_value);
lengths_shape_max.push_back(maxlen_value);

View File

@ -17,6 +17,7 @@
#define MINDSPORE_CORE_BASE_COMPLEX_STORAGE_H_
#include "base/float16.h"
#include "utils/ms_utils.h"
namespace mindspore {
@ -62,7 +63,12 @@ struct alignas(sizeof(T) * 2) ComplexStorage {
template <typename T>
inline bool operator==(const ComplexStorage<T> &lhs, const ComplexStorage<T> &rhs) {
return (lhs.real_ - rhs.real_ == 0) && (lhs.imag_ - rhs.imag_ == 0);
if constexpr (std::is_same_v<T, double>) {
return common::IsDoubleEqual(lhs.real_, rhs.real_) && common::IsDoubleEqual(lhs.imag_, rhs.imag_);
} else if constexpr (std::is_same_v<T, float>) {
return common::IsFloatEqual(lhs.real_, rhs.real_) && common::IsFloatEqual(lhs.imag_, rhs.imag_);
}
return (lhs.real_ == rhs.real_) && (lhs.imag_ == rhs.imag_);
}
template <typename T>

View File

@ -17,7 +17,6 @@
#ifndef MINDSPORE_CORE_IR_DTYPE_NUMBER_H_
#define MINDSPORE_CORE_IR_DTYPE_NUMBER_H_
#include <iostream>
#include <map>
#include <memory>
#include <sstream>

View File

@ -236,12 +236,12 @@ class MS_CORE_API Object : public Type {
//
// TypeId name map
//
const mindspore::HashMap<TypeId, std::string> type_name_map = {
inline const mindspore::HashMap<TypeId, std::string> type_name_map = {
{kNumberTypeBool, "bool_"}, {kNumberTypeInt8, "int8"}, {kNumberTypeUInt8, "uint8"},
{kNumberTypeInt16, "int16"}, {kNumberTypeInt32, "int32"}, {kNumberTypeInt64, "int64"},
{kNumberTypeFloat16, "float16"}, {kNumberTypeFloat32, "float32"}, {kNumberTypeFloat64, "float64"}};
const mindspore::HashMap<TypeId, int> type_priority_map = {
inline const mindspore::HashMap<TypeId, int> type_priority_map = {
{kNumberTypeBool, 0}, {kNumberTypeUInt8, 1}, {kNumberTypeInt8, 2},
{kNumberTypeInt16, 3}, {kNumberTypeInt32, 4}, {kNumberTypeInt64, 5},
{kNumberTypeFloat16, 6}, {kNumberTypeFloat32, 7}, {kNumberTypeFloat64, 8}};

View File

@ -298,7 +298,7 @@ bool FuncGraphManager::recursive(const FuncGraphPtr &fg) const {
std::shared_ptr<std::list<FuncGraphPtr>> FuncGraphManager::recursive_graphs(const FuncGraphPtr &fg) const {
MS_EXCEPTION_IF_NULL(fg);
if (recursive(fg)) {
if (!recursive_->recursive_map().count(fg)) {
if (recursive_->recursive_map().count(fg) == 0) {
auto trace = std::list<FuncGraphPtr>();
recursive_->CheckRecursiveGraphs(fg, &trace);
}
@ -851,10 +851,9 @@ void FuncGraphTransaction::AddEdge(const AnfNodePtr &src_node, const AnfNodePtr
void FuncGraphTransaction::Commit() { manager_->CommitChanges(std::move(changes_)); }
DepComputer::DepComputer(const FuncGraphManager *const manager) : manager_(manager) {
DepComputer::DepComputer(const FuncGraphManager *const manager) : manager_(manager), validate_(false) {
MS_EXCEPTION_IF_NULL(manager_);
manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer);
validate_ = false;
}
void DepComputer::Recompute() {
@ -939,7 +938,7 @@ void ParentComputer::RealRecompute(FuncGraphPtr fg) {
for (auto &dep : deps) {
auto parent_deps = this->manager_->func_graph_parents_total(dep);
for (auto &p_d : parent_deps) {
if (deps_copy.count(p_d)) {
if (deps_copy.count(p_d) > 0) {
(void)deps_copy.erase(p_d);
}
}
@ -1085,7 +1084,7 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<F
CheckRecursiveGraphs(iter->first, trace);
}
trace->pop_back();
if (!recursive_map_.count(fg)) {
if (recursive_map_.count(fg) == 0) {
recursive_map_[fg] = nullptr;
}
}