forked from mindspore-Ecosystem/mindspore
clean code
This commit is contained in:
parent
cf07fbc41b
commit
5fe2ba9d42
|
@ -130,7 +130,7 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
|
|||
void ReplaceGraphInput(const AnfNodePtr &old_parameter, const AnfNodePtr &new_parameter);
|
||||
std::vector<AnfNodePtr> outputs() const;
|
||||
CNodePtr NewCNode(std::vector<AnfNodePtr> &&inputs) override;
|
||||
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs) override;
|
||||
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs = std::vector<AnfNodePtr>()) override;
|
||||
CNodePtr NewCNodeWithInfos(const std::vector<AnfNodePtr> &inputs, const CNodePtr &ori_cnode = nullptr);
|
||||
void CreateKernelInfoFromNewParameter(const CNodePtr &cnode) const;
|
||||
CNodePtr NewCNode(const CNodePtr &cnode);
|
||||
|
|
|
@ -1448,7 +1448,6 @@ const TypeId AbstractSparseTensor::GetTensorTypeIdAt(size_t index) const {
|
|||
if (index >= shape_idx || index < 0) {
|
||||
MS_LOG(EXCEPTION) << "Index must be in range of [0, " << shape_idx << "), but got " << index << " for "
|
||||
<< ToString();
|
||||
return kTypeUnknown;
|
||||
}
|
||||
auto abs_tensor = GetAbsPtrAt<abstract::AbstractTensorPtr>(index);
|
||||
MS_EXCEPTION_IF_NULL(abs_tensor);
|
||||
|
@ -1459,7 +1458,6 @@ const TypeId AbstractSparseTensor::GetShapeTypeIdAt(size_t index) const {
|
|||
if (index >= shape()->size() || index < 0) {
|
||||
MS_LOG(EXCEPTION) << "Index must be in range of [0, " << shape()->size() << "), but got " << index << " for "
|
||||
<< ToString();
|
||||
return kTypeUnknown;
|
||||
}
|
||||
return shape()->elements()[index]->BuildType()->type_id();
|
||||
}
|
||||
|
|
|
@ -924,9 +924,7 @@ AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const Primitive
|
|||
|
||||
ShapeVector lengths_shape = lengths->shape()->shape();
|
||||
ShapeVector lengths_shape_min = lengths->shape()->min_shape();
|
||||
|
||||
ShapeVector lengths_shape_max = lengths->shape()->max_shape();
|
||||
|
||||
if (!lengths_shape_max.empty() && !lengths_shape_min.empty()) {
|
||||
lengths_shape_min.push_back(maxlen_value);
|
||||
lengths_shape_max.push_back(maxlen_value);
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#define MINDSPORE_CORE_BASE_COMPLEX_STORAGE_H_
|
||||
|
||||
#include "base/float16.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
||||
|
@ -62,7 +63,12 @@ struct alignas(sizeof(T) * 2) ComplexStorage {
|
|||
|
||||
template <typename T>
|
||||
inline bool operator==(const ComplexStorage<T> &lhs, const ComplexStorage<T> &rhs) {
|
||||
return (lhs.real_ - rhs.real_ == 0) && (lhs.imag_ - rhs.imag_ == 0);
|
||||
if constexpr (std::is_same_v<T, double>) {
|
||||
return common::IsDoubleEqual(lhs.real_, rhs.real_) && common::IsDoubleEqual(lhs.imag_, rhs.imag_);
|
||||
} else if constexpr (std::is_same_v<T, float>) {
|
||||
return common::IsFloatEqual(lhs.real_, rhs.real_) && common::IsFloatEqual(lhs.imag_, rhs.imag_);
|
||||
}
|
||||
return (lhs.real_ == rhs.real_) && (lhs.imag_ == rhs.imag_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
#ifndef MINDSPORE_CORE_IR_DTYPE_NUMBER_H_
|
||||
#define MINDSPORE_CORE_IR_DTYPE_NUMBER_H_
|
||||
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
|
|
|
@ -236,12 +236,12 @@ class MS_CORE_API Object : public Type {
|
|||
//
|
||||
// TypeId name map
|
||||
//
|
||||
const mindspore::HashMap<TypeId, std::string> type_name_map = {
|
||||
inline const mindspore::HashMap<TypeId, std::string> type_name_map = {
|
||||
{kNumberTypeBool, "bool_"}, {kNumberTypeInt8, "int8"}, {kNumberTypeUInt8, "uint8"},
|
||||
{kNumberTypeInt16, "int16"}, {kNumberTypeInt32, "int32"}, {kNumberTypeInt64, "int64"},
|
||||
{kNumberTypeFloat16, "float16"}, {kNumberTypeFloat32, "float32"}, {kNumberTypeFloat64, "float64"}};
|
||||
|
||||
const mindspore::HashMap<TypeId, int> type_priority_map = {
|
||||
inline const mindspore::HashMap<TypeId, int> type_priority_map = {
|
||||
{kNumberTypeBool, 0}, {kNumberTypeUInt8, 1}, {kNumberTypeInt8, 2},
|
||||
{kNumberTypeInt16, 3}, {kNumberTypeInt32, 4}, {kNumberTypeInt64, 5},
|
||||
{kNumberTypeFloat16, 6}, {kNumberTypeFloat32, 7}, {kNumberTypeFloat64, 8}};
|
||||
|
|
|
@ -298,7 +298,7 @@ bool FuncGraphManager::recursive(const FuncGraphPtr &fg) const {
|
|||
std::shared_ptr<std::list<FuncGraphPtr>> FuncGraphManager::recursive_graphs(const FuncGraphPtr &fg) const {
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
if (recursive(fg)) {
|
||||
if (!recursive_->recursive_map().count(fg)) {
|
||||
if (recursive_->recursive_map().count(fg) == 0) {
|
||||
auto trace = std::list<FuncGraphPtr>();
|
||||
recursive_->CheckRecursiveGraphs(fg, &trace);
|
||||
}
|
||||
|
@ -851,10 +851,9 @@ void FuncGraphTransaction::AddEdge(const AnfNodePtr &src_node, const AnfNodePtr
|
|||
|
||||
void FuncGraphTransaction::Commit() { manager_->CommitChanges(std::move(changes_)); }
|
||||
|
||||
DepComputer::DepComputer(const FuncGraphManager *const manager) : manager_(manager) {
|
||||
DepComputer::DepComputer(const FuncGraphManager *const manager) : manager_(manager), validate_(false) {
|
||||
MS_EXCEPTION_IF_NULL(manager_);
|
||||
manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer);
|
||||
validate_ = false;
|
||||
}
|
||||
|
||||
void DepComputer::Recompute() {
|
||||
|
@ -939,7 +938,7 @@ void ParentComputer::RealRecompute(FuncGraphPtr fg) {
|
|||
for (auto &dep : deps) {
|
||||
auto parent_deps = this->manager_->func_graph_parents_total(dep);
|
||||
for (auto &p_d : parent_deps) {
|
||||
if (deps_copy.count(p_d)) {
|
||||
if (deps_copy.count(p_d) > 0) {
|
||||
(void)deps_copy.erase(p_d);
|
||||
}
|
||||
}
|
||||
|
@ -1085,7 +1084,7 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<F
|
|||
CheckRecursiveGraphs(iter->first, trace);
|
||||
}
|
||||
trace->pop_back();
|
||||
if (!recursive_map_.count(fg)) {
|
||||
if (recursive_map_.count(fg) == 0) {
|
||||
recursive_map_[fg] = nullptr;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue