fix code check

This commit is contained in:
lianliguang 2022-01-14 15:39:39 +08:00
parent 16bb659a6b
commit ffe4393e95
22 changed files with 36 additions and 163 deletions

View File

@ -24,6 +24,7 @@
"mindspore/mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.cc" "useStlAlgorithm" "mindspore/mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.cc" "useStlAlgorithm"
"mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/cast_gpu_kernel.cc" "unknownMacro" "mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/cast_gpu_kernel.cc" "unknownMacro"
"mindspore/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc" "nullPointerArithmeticRedundantCheck" "mindspore/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc" "nullPointerArithmeticRedundantCheck"
"mindspore/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.cc" "containerOutOfBounds"
# MindData # MindData
"mindspore/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.cc" "useStlAlgorithm" "mindspore/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.cc" "useStlAlgorithm"

View File

@ -826,7 +826,7 @@ REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(J, prim::kPrimJ, InferImplJ, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Shard, prim::kPrimShard, InferImplShard, nullptr); REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Shard, prim::kPrimShard, InferImplShard, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs, REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs,
InferImplBroadcastGradientArgs, nullptr); InferImplBroadcastGradientArgs, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(MakeSiice, prim::kPrimMakeSlice, InferImplMakeSlice, nullptr); REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(MakeSlice, prim::kPrimMakeSlice, InferImplMakeSlice, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(SliceGetItem, prim::kPrimSliceGetItem, InferImplSliceGetItem, nullptr); REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(SliceGetItem, prim::kPrimSliceGetItem, InferImplSliceGetItem, nullptr);
} // namespace abstract } // namespace abstract
} // namespace mindspore } // namespace mindspore

View File

@ -84,9 +84,9 @@ class Resource : public ResourceBase {
loop_size_ = size; loop_size_ = size;
} }
void set_is_load(bool flag) { is_load_ = flag; } void set_is_load(bool flag) { is_load_ = flag; }
bool is_load() { return is_load_; } bool is_load() const { return is_load_; }
bool vm_loop_flag() { return vm_loop_flag_; } bool vm_loop_flag() const { return vm_loop_flag_; }
int64_t loop_size() { return loop_size_; } int64_t loop_size() const { return loop_size_; }
void set_layout_map(const LayoutMap &layout_map) { layout_map_ = layout_map; } void set_layout_map(const LayoutMap &layout_map) { layout_map_ = layout_map; }
const LayoutMap &get_layout_map() const { return layout_map_; } const LayoutMap &get_layout_map() const { return layout_map_; }

View File

@ -294,7 +294,7 @@ class AsyncInferTask {
static AsyncInferTaskPtr MakeShared(const AsyncAbstractPtr &abstract, const std::string &thread = "") { static AsyncInferTaskPtr MakeShared(const AsyncAbstractPtr &abstract, const std::string &thread = "") {
std::string thread_id = thread; std::string thread_id = thread;
if (thread_id == "") { if (thread_id == "") {
thread_id = AnalysisSchedule::GetInstance().thread_id(); thread_id = AnalysisSchedule::thread_id();
} }
MS_EXCEPTION_IF_NULL(abstract); MS_EXCEPTION_IF_NULL(abstract);
auto ret = std::make_shared<AsyncInferTask>(thread_id, abstract); auto ret = std::make_shared<AsyncInferTask>(thread_id, abstract);
@ -387,10 +387,10 @@ class EvaluatorCacheMgr {
~EvaluatorCacheMgr() = default; ~EvaluatorCacheMgr() = default;
void Clear() { eval_result_cache_.clear(); } void Clear() { eval_result_cache_.clear(); }
const EvalResultCache &GetCache() { return eval_result_cache_; } const EvalResultCache &GetCache() const { return eval_result_cache_; }
EvalResultPtr GetValue(const AbstractBasePtrList &key) { return eval_result_cache_.get(key); } EvalResultPtr GetValue(const AbstractBasePtrList &key) { return eval_result_cache_.get(key); }
void SetValue(const AbstractBasePtrList &key, const EvalResultPtr &arg) { eval_result_cache_.set(key, arg); } void SetValue(const AbstractBasePtrList &key, const EvalResultPtr &arg) { eval_result_cache_.set(key, arg); }
size_t GetSize() { return eval_result_cache_.size(); } size_t GetSize() const { return eval_result_cache_.size(); }
private: private:
EvalResultCache eval_result_cache_; EvalResultCache eval_result_cache_;

View File

@ -616,7 +616,6 @@ class SideEffectFinder {
if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple)) { if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple)) {
if (tuple_indexes->empty()) { if (tuple_indexes->empty()) {
MS_LOG(EXCEPTION) << "Unexpected make_tuple: " << cnode->DebugString(2); MS_LOG(EXCEPTION) << "Unexpected make_tuple: " << cnode->DebugString(2);
return {EffectInfo::kDetected, false, false, false};
} }
// Pop out tuple index. // Pop out tuple index.
auto top_index = tuple_indexes->top(); auto top_index = tuple_indexes->top();

View File

@ -73,7 +73,7 @@ class Evaluator : public Base {
return nullptr; return nullptr;
} }
auto is_abstract = std::any_of(args_spec_list.begin(), args_spec_list.end(), [](auto &arg) { auto is_abstract = std::any_of(args_spec_list.begin(), args_spec_list.end(), [](auto &arg) -> bool {
if (arg->BuildType()->type_id() == kObjectTypeUndeterminedType) { if (arg->BuildType()->type_id() == kObjectTypeUndeterminedType) {
return true; return true;
} }
@ -94,7 +94,6 @@ class Evaluator : public Base {
EvaluatorCacheMgrPtr evaluator_cache_mgr() const { return evaluator_cache_mgr_; } EvaluatorCacheMgrPtr evaluator_cache_mgr() const { return evaluator_cache_mgr_; }
EvaluatorAttrCachePtr attr_cache() const { return attr_cache_; } EvaluatorAttrCachePtr attr_cache() const { return attr_cache_; }
const std::recursive_timed_mutex &eval_lock() const { return eval_lock_; } const std::recursive_timed_mutex &eval_lock() const { return eval_lock_; }
protected: protected:

View File

@ -89,7 +89,7 @@ class UniformPrimEvaluator final : public TrivialPrimEvaluator {
} }
} }
} }
~UniformPrimEvaluator() override = default; ~UniformPrimEvaluator() override { impl_ = nullptr; };
MS_DECLARE_PARENT(UniformPrimEvaluator, TrivialPrimEvaluator); MS_DECLARE_PARENT(UniformPrimEvaluator, TrivialPrimEvaluator);
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
@ -122,7 +122,7 @@ class DoSignatureEvaluator final : public Evaluator {
~DoSignatureEvaluator() override = default; ~DoSignatureEvaluator() override = default;
MS_DECLARE_PARENT(DoSignatureEvaluator, Evaluator); MS_DECLARE_PARENT(DoSignatureEvaluator, Evaluator);
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
const AnfNodeConfigPtr &out_config = nullptr) override; const AnfNodeConfigPtr &out_config) override;
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override { EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
@ -138,7 +138,7 @@ class UnpackGraphEvaluator final : public Evaluator {
~UnpackGraphEvaluator() override = default; ~UnpackGraphEvaluator() override = default;
MS_DECLARE_PARENT(UnpackGraphEvaluator, Evaluator); MS_DECLARE_PARENT(UnpackGraphEvaluator, Evaluator);
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
const AnfNodeConfigPtr &out_config = nullptr) override; const AnfNodeConfigPtr &out_config) override;
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override { EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
@ -155,7 +155,7 @@ class MixedPrecisionCastEvaluator final : public Evaluator {
~MixedPrecisionCastEvaluator() override = default; ~MixedPrecisionCastEvaluator() override = default;
MS_DECLARE_PARENT(MixedPrecisionCastEvaluator, Evaluator); MS_DECLARE_PARENT(MixedPrecisionCastEvaluator, Evaluator);
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
const AnfNodeConfigPtr &out_config = nullptr) override; const AnfNodeConfigPtr &out_config) override;
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override { EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";

View File

@ -27,7 +27,6 @@
namespace mindspore { namespace mindspore {
namespace abstract { namespace abstract {
class StackFrame; class StackFrame;
using StackFramePtr = std::shared_ptr<StackFrame>;
using EvaluatorWeakPtr = std::weak_ptr<Evaluator>; using EvaluatorWeakPtr = std::weak_ptr<Evaluator>;
using BaseFuncGraphEvaluatorPtr = std::shared_ptr<BaseFuncGraphEvaluator>; using BaseFuncGraphEvaluatorPtr = std::shared_ptr<BaseFuncGraphEvaluator>;
@ -65,7 +64,7 @@ class StackFrame final : public Base {
// Return back from branch func graph. // Return back from branch func graph.
void Back(const AnalysisEnginePtr &engine, const StackFramePtr &last_stack_frame, const EvalResultPtr &eval_result); void Back(const AnalysisEnginePtr &engine, const StackFramePtr &last_stack_frame, const EvalResultPtr &eval_result);
bool Done() { return done_; } bool Done() const { return done_; }
AnfNodePtr &CurrentNode() { AnfNodePtr &CurrentNode() {
if (slot_index_ >= node_slots_.size()) { if (slot_index_ >= node_slots_.size()) {
@ -92,7 +91,7 @@ class StackFrame final : public Base {
AnalysisContextPtr current_context() const { return current_context_; } AnalysisContextPtr current_context() const { return current_context_; }
AnalysisContextPtr parent_context() const { return parent_context_; } AnalysisContextPtr parent_context() const { return parent_context_; }
const AbstractBasePtrList &args_abs_list() { return args_abs_list_; } const AbstractBasePtrList &args_abs_list() const { return args_abs_list_; }
void set_args_abs_list(const AbstractBasePtrList &&args_abs_list) { args_abs_list_ = args_abs_list; } void set_args_abs_list(const AbstractBasePtrList &&args_abs_list) { args_abs_list_ = args_abs_list; }
std::string ToString() const override { std::string ToString() const override {

View File

@ -63,7 +63,7 @@ using AttrValueMapPtr = std::shared_ptr<AttrValueMap>;
// the class to save evaluated result: abstract value and modified attribute // the class to save evaluated result: abstract value and modified attribute
class EvalResult : public Base { class EvalResult : public Base {
public: public:
EvalResult(AbstractBasePtr abs, AttrValueMapPtr attr) : abstract_(abs), attribute_(attr) {} EvalResult(const AbstractBasePtr &abs, const AttrValueMapPtr &attr) : abstract_(abs), attribute_(attr) {}
~EvalResult() override = default; ~EvalResult() override = default;
MS_DECLARE_PARENT(EvalResult, Base); MS_DECLARE_PARENT(EvalResult, Base);
AbstractBasePtr abstract() { return abstract_; } AbstractBasePtr abstract() { return abstract_; }
@ -261,7 +261,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
forward_count_ = 0; forward_count_ = 0;
enable_recursive_eval_ = (common::GetEnv("MS_DEV_RECURSIVE_EVAL") == "1"); enable_recursive_eval_ = (common::GetEnv("MS_DEV_RECURSIVE_EVAL") == "1");
} }
~AnalysisEngine() = default; virtual ~AnalysisEngine() = default;
// func_graph: The func_graph to analyze. // func_graph: The func_graph to analyze.
// args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase. // args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase.

View File

@ -71,7 +71,7 @@ class PrimitivePy : public Primitive {
py::dict RunInfer(const py::tuple &args); py::dict RunInfer(const py::tuple &args);
void RunCheck(const py::tuple &args); void RunCheck(const py::tuple &args);
py::object RunInferValue(const py::tuple &args); py::object RunInferValue(const py::tuple &args);
bool HasPyObj() { return python_obj_.operator bool(); } bool HasPyObj() const { return python_obj_.operator bool(); }
PrimitivePtr Clone() override; PrimitivePtr Clone() override;
PrimitivePyAdapterPtr adapter() const { return adapter_; } PrimitivePyAdapterPtr adapter() const { return adapter_; }
void set_bprop_cls_name(const std::string &name) { bprop_cls_name_ = name; } void set_bprop_cls_name(const std::string &name) { bprop_cls_name_ = name; }
@ -102,7 +102,7 @@ class PrimitivePyAdapter {
void set_hook(const py::function &hook); void set_hook(const py::function &hook);
void set_instance_name(const std::string &s); void set_instance_name(const std::string &s);
void set_attached_primitive(const PrimitivePyPtr &prim); void set_attached_primitive(const PrimitivePyPtr &prim);
PrimitivePyPtr attached_primitive() { return attached_primitive_.lock(); } PrimitivePyPtr attached_primitive() const { return attached_primitive_.lock(); }
std::string name() const { return name_; } std::string name() const { return name_; }
void set_name(const std::string &name) { name_ = name; } void set_name(const std::string &name) { name_ = name; }
const bool parse_info_ = true; const bool parse_info_ = true;

View File

@ -16,12 +16,6 @@
#include "pybind_api/ir/tensor_py.h" #include "pybind_api/ir/tensor_py.h"
#include <vector>
#include <sstream>
#include <string>
#include <utility>
#include <complex>
#include "pybind_api/api_register.h" #include "pybind_api/api_register.h"
#include "abstract/abstract_value.h" #include "abstract/abstract_value.h"
#include "utils/shape_utils.h" #include "utils/shape_utils.h"
@ -125,6 +119,8 @@ static std::string GetPyTypeFormat(TypeId data_type) {
return py::format_descriptor<std::complex<float>>::format(); return py::format_descriptor<std::complex<float>>::format();
case TypeId::kNumberTypeComplex128: case TypeId::kNumberTypeComplex128:
return py::format_descriptor<std::complex<double>>::format(); return py::format_descriptor<std::complex<double>>::format();
case TypeId::kMetaTypeType:
case TypeId::kMetaTypeEllipsis:
default: default:
MS_LOG(WARNING) << "Unsupported DataType " << data_type << "."; MS_LOG(WARNING) << "Unsupported DataType " << data_type << ".";
return ""; return "";
@ -182,7 +178,7 @@ class TensorDataNumpy : public TensorData {
} }
private: private:
void *buffer_data() { return buffer_.ptr; } void *buffer_data() const { return buffer_.ptr; }
// The internal buffer. // The internal buffer.
py::buffer_info buffer_; py::buffer_info buffer_;

View File

@ -15,8 +15,6 @@
*/ */
#include "pybind_api/random_normal/random_cpu_kernel.h" #include "pybind_api/random_normal/random_cpu_kernel.h"
#include <thread>
#include <memory>
#include "runtime/device/cpu/cpu_device_address.h" #include "runtime/device/cpu/cpu_device_address.h"
#include "ir/tensor.h" #include "ir/tensor.h"

View File

@ -1138,12 +1138,11 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(con
parser->Parse(control_nodes_, graphs, device_contexts, root_graph, func_graph_to_kernel_graphs); parser->Parse(control_nodes_, graphs, device_contexts, root_graph, func_graph_to_kernel_graphs);
runtime::KernelMapPosition outputs_order; runtime::KernelMapPosition outputs_order;
size_t outputs_num = 0;
const auto &root_output = const auto &root_output =
AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0, false, {prim::kPrimTupleGetItem}).first; AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0, false, {prim::kPrimTupleGetItem}).first;
size_t position = 0; size_t position = 0;
auto outputs = AnfAlgo::GetAllOutputWithIndex(root_output); auto outputs = AnfAlgo::GetAllOutputWithIndex(root_output);
outputs_num = outputs.size(); size_t outputs_num = outputs.size();
for (const auto &output : outputs) { for (const auto &output : outputs) {
if (outputs_order.count(output) == 0) { if (outputs_order.count(output) == 0) {
outputs_order[output] = {position++}; outputs_order[output] = {position++};

View File

@ -134,7 +134,7 @@ void FinalVM::Popsp() {
int64_t sp = retsp_.top(); int64_t sp = retsp_.top();
MS_LOG(DEBUG) << "Current sp:" << sp_ << ", before sp:" << sp << ", " << sp_ - sp; MS_LOG(DEBUG) << "Current sp:" << sp_ << ", before sp:" << sp << ", " << sp_ - sp;
if (sp_ >= sp) { if (sp_ >= sp) {
Pop(sp_ - sp + 1); Pop((sp_ - sp) + 1);
retsp_.pop(); retsp_.pop();
} else { } else {
MS_LOG(EXCEPTION) << "Stack point sp_:" << sp << " must be bigger than sp:" << sp_; MS_LOG(EXCEPTION) << "Stack point sp_:" << sp << " must be bigger than sp:" << sp_;

View File

@ -268,14 +268,12 @@ class MS_CORE_API PartialAbstractClosure final : public AbstractFuncAtom {
/// \brief Get the pre-provided arguments. /// \brief Get the pre-provided arguments.
/// ///
/// \return The pre-provided arguments. /// \return The pre-provided arguments.
const AbstractBasePtrList &args() { return args_spec_list_; } const AbstractBasePtrList &args() const { return args_spec_list_; }
ValuePtr RealBuildValue() const override { return fn_->BuildValue(); }
/// \brief Get the CNode this PartialAbstractClosure evaluated from. /// \brief Get the CNode this PartialAbstractClosure evaluated from.
/// ///
/// \return The CNode this PartialAbstractClosure evaluated from. /// \return The CNode this PartialAbstractClosure evaluated from.
AnfNodePtr node() { return node_.lock(); } AnfNodePtr node() const { return node_.lock(); }
/// \brief Set the CNode this PartialAbstractClosure evaluated from. /// \brief Set the CNode this PartialAbstractClosure evaluated from.
/// ///
@ -292,6 +290,9 @@ class MS_CORE_API PartialAbstractClosure final : public AbstractFuncAtom {
std::string ToString() const override; std::string ToString() const override;
protected:
ValuePtr RealBuildValue() const override { return fn_->BuildValue(); }
private: private:
AbstractFuncAtomPtr fn_; AbstractFuncAtomPtr fn_;
AbstractBasePtrList args_spec_list_; AbstractBasePtrList args_spec_list_;

View File

@ -147,7 +147,7 @@ class MS_CORE_API AbstractBase : public Base {
/// \brief Set the function, which prints the debug info. /// \brief Set the function, which prints the debug info.
/// ///
/// \param[in] trace_node_provider The function. /// \param[in] trace_node_provider The function.
static void set_trace_node_provider(TraceNodeProvider trace_node_provider) { static void set_trace_node_provider(const TraceNodeProvider &trace_node_provider) {
trace_node_provider_ = trace_node_provider; trace_node_provider_ = trace_node_provider;
} }

View File

@ -18,11 +18,6 @@
#include "abstract/dshape.h" #include "abstract/dshape.h"
#include <exception>
#include <iostream>
#include "utils/log_adapter.h"
namespace mindspore { namespace mindspore {
namespace abstract { namespace abstract {
namespace { namespace {

View File

@ -163,17 +163,17 @@ class MS_CORE_API Shape final : public BaseShape {
/// \brief Get shape dimensions. /// \brief Get shape dimensions.
/// ///
/// \return Shape dimensions. /// \return Shape dimensions.
const ShapeVector &shape() { return shape_; } const ShapeVector &shape() const { return shape_; }
/// \brief Get minimum shape dimensions. /// \brief Get minimum shape dimensions.
/// ///
/// \return Minimum shape dimensions. /// \return Minimum shape dimensions.
const ShapeVector &min_shape() { return min_shape_; } const ShapeVector &min_shape() const { return min_shape_; }
/// \brief Get maximum shape dimensions. /// \brief Get maximum shape dimensions.
/// ///
/// \return Maximum shape dimensions. /// \return Maximum shape dimensions.
const ShapeVector &max_shape() { return max_shape_; } const ShapeVector &max_shape() const { return max_shape_; }
bool IsDynamic() const override { bool IsDynamic() const override {
return std::any_of(shape_.begin(), shape_.end(), [](int64_t s) { return s < 0; }); return std::any_of(shape_.begin(), shape_.end(), [](int64_t s) { return s < 0; });

View File

@ -22,12 +22,6 @@
namespace mindspore { namespace mindspore {
namespace abstract { namespace abstract {
const size_t stride_num_element = 2;
const size_t stride_start_idx = 2;
const size_t dilation_num_element = 2;
const size_t dilation_start_idx = 2;
const size_t padding_num_element = 4;
const size_t padding_start_idx = 0;
int64_t GetAndCheckFormat(const ValuePtr &value) { int64_t GetAndCheckFormat(const ValuePtr &value) {
int64_t data_format; int64_t data_format;
bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format); bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format);
@ -81,7 +75,7 @@ AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &
auto pad_mode_ptr = primitive->GetAttr("pad_mode"); auto pad_mode_ptr = primitive->GetAttr("pad_mode");
if (pad_mode_ptr != nullptr) { if (pad_mode_ptr != nullptr) {
int64_t pad_mode; int64_t pad_mode;
const size_t middle = 2; const int64_t middle = 2;
CheckAndConvertUtils::GetPadModEnumValue(pad_mode_ptr, &pad_mode, true); CheckAndConvertUtils::GetPadModEnumValue(pad_mode_ptr, &pad_mode, true);
if (pad_mode == static_cast<int64_t>(PadMode::VALID)) { if (pad_mode == static_cast<int64_t>(PadMode::VALID)) {
padding = 0; padding = 0;
@ -97,7 +91,7 @@ AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &
MS_LOG(EXCEPTION) << "Unsupported pooling mode: " << mode << "."; MS_LOG(EXCEPTION) << "Unsupported pooling mode: " << mode << ".";
} }
} }
const size_t twice = 2; const int64_t twice = 2;
int64_t h_out = (((h_input + twice * padding - (window - 1)) - 1) / stride) + 1; int64_t h_out = (((h_input + twice * padding - (window - 1)) - 1) / stride) + 1;
int64_t w_out = (((w_input + twice * padding - (window - 1)) - 1) / stride) + 1; int64_t w_out = (((w_input + twice * padding - (window - 1)) - 1) / stride) + 1;
ShapeVector shape_out = {input_shape->shape()[0], input_shape->shape()[1], h_out, w_out}; ShapeVector shape_out = {input_shape->shape()[0], input_shape->shape()[1], h_out, w_out};
@ -202,50 +196,6 @@ AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const Primit
return std::make_shared<AbstractTuple>(rets); return std::make_shared<AbstractTuple>(rets);
} }
void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pad_list, const int64_t x_h,
const int64_t x_w, const std::vector<int64_t> &kernel, const std::vector<int64_t> &stride,
const std::vector<int64_t> &dilation, const int64_t &pad_mode,
const std::vector<int64_t> &padding) {
const size_t middle = 2;
const size_t second_index = 2;
const size_t third_index = 3;
if (pad_mode == static_cast<int64_t>(PadMode::VALID)) {
output_hw->push_back(static_cast<int64_t>(std::ceil(((x_h * 1.0) - dilation[0] * (kernel[0] - 1)) / stride[0])));
output_hw->push_back(static_cast<int64_t>(std::ceil(((x_w * 1.0) - dilation[1] * (kernel[1] - 1)) / stride[1])));
const size_t nhwc = 4;
(void)pad_list->insert(pad_list->begin(), nhwc, 0);
} else if (pad_mode == static_cast<int64_t>(PadMode::SAME)) {
output_hw->push_back(static_cast<int64_t>(std::ceil((x_h * 1.0) / stride[0])));
output_hw->push_back(static_cast<int64_t>(std::ceil((x_w * 1.0) / stride[1])));
int64_t pad_needed_h = (output_hw->at(0) - 1) * stride[0] + dilation[0] * (kernel[0] - 1) + 1 - x_h;
pad_needed_h = std::max((int64_t)0, pad_needed_h);
pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_h / middle)));
pad_list->push_back(pad_needed_h - pad_list->at(0));
int64_t pad_needed_w = (output_hw->at(1) - 1) * stride[1] + dilation[1] * (kernel[1] - 1) + 1 - x_w;
pad_needed_w = std::max((int64_t)0, pad_needed_w);
pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_w / middle)));
pad_list->push_back(pad_needed_w - pad_list->at(middle));
} else if (pad_mode == static_cast<int64_t>(PadMode::PAD)) {
(void)pad_list->insert(pad_list->begin(), padding.begin(), padding.end());
output_hw->push_back(static_cast<int64_t>(std::floor(
1 + (((x_h * 1.0) + pad_list->at(0) + pad_list->at(1) - kernel[0]) - (kernel[0] - 1) * (dilation[0] - 1)) /
stride[0])));
output_hw->push_back(static_cast<int64_t>(
std::floor(1 + (((x_w * 1.0) + pad_list->at(second_index) + pad_list->at(third_index) - kernel[1]) -
(kernel[1] - 1) * (dilation[1] - 1)) /
stride[1])));
}
}
void CheckShape(const std::string &op_name, const ShapeVector &w_shape, const AbstractTensorPtr &input_w) {
ShapeVector w_min_shape = input_w->shape()->min_shape();
ShapeVector w_max_shape = input_w->shape()->max_shape();
CheckMinMaxShape(w_shape, &w_min_shape, &w_max_shape);
CheckShapeAnyAndPositive(op_name + " w_shape", w_shape);
CheckShapeAllPositive(op_name + " w_min_shape", w_min_shape);
CheckShapeAllPositive(op_name + " w_max_shape", w_max_shape);
}
AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) { const AbstractBasePtrList &args_spec_list) {
// Inputs: at least one tensor(y_backprop) // Inputs: at least one tensor(y_backprop)

View File

@ -18,9 +18,6 @@
#include "abstract/utils.h" #include "abstract/utils.h"
#include <string>
#include <sstream>
#include <memory>
#include "utils/ms_context.h" #include "utils/ms_context.h"
#include "utils/symbolic.h" #include "utils/symbolic.h"
#include "abstract/param_validator.h" #include "abstract/param_validator.h"

View File

@ -59,7 +59,7 @@ const std::string GetSubModuleName(SubModuleId module_id) { return sub_module_na
// export GetTimeString for all sub modules // export GetTimeString for all sub modules
std::string GetTimeString() { std::string GetTimeString() {
#define BUFLEN 80 constexpr auto BUFLEN = 80;
char buf[BUFLEN]; char buf[BUFLEN];
(void)memset(buf, '\0', BUFLEN); (void)memset(buf, '\0', BUFLEN);
#if defined(_WIN32) || defined(_WIN64) #if defined(_WIN32) || defined(_WIN64)

View File

@ -26,58 +26,6 @@
namespace mindspore { namespace mindspore {
namespace ops { namespace ops {
namespace {
int64_t Log2Ceil(int64_t length) {
if (length == 0) {
return -1;
}
int64_t floor = 0;
for (int64_t i = 4; i >= 0; --i) {
const int64_t shift = static_cast<int64_t>(1UL << static_cast<unsigned>(i));
int64_t tmp = SizeToLong(static_cast<uint64_t>(length) >> static_cast<uint64_t>(shift));
if (tmp != 0) {
length = tmp;
floor += shift;
}
}
return length == (length & ~(unsigned int)(length - 1)) ? floor : floor + 1;
}
int64_t GetFftLength(int64_t length) {
int64_t shift = Log2Ceil(length);
return SizeToLong(1UL << LongToSize(shift));
}
abstract::ShapePtr AudioSpectrogramInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
if (input_shape.size() != 2) {
MS_LOG(ERROR) << "input shape is error, which need to be 2 dimensions";
}
auto window_size = GetValue<int64_t>(primitive->GetAttr(kWindowSize));
if (window_size < 2) {
MS_LOG(ERROR) << "window size is too short, now is " << window_size;
}
auto stride_size = GetValue<int64_t>(primitive->GetAttr(kStride));
if (stride_size < 1) {
MS_LOG(ERROR) << "stride must be positive, now is " << stride_size;
}
std::vector<int64_t> infer_shape;
infer_shape.push_back(input_shape[1]);
int64_t sample_sub_window = input_shape[0] - window_size;
infer_shape.push_back(sample_sub_window < 0 ? 0 : 1 + sample_sub_window / stride_size);
int64_t fft_length = GetFftLength(window_size);
infer_shape.push_back(fft_length / 2 + 1);
MS_LOG(ERROR) << infer_shape;
return std::make_shared<abstract::Shape>(infer_shape);
}
TypePtr AudioSpectrogramInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
const size_t x_index = 0;
return CheckAndConvertUtils::GetTensorInputType(prim->name(), input_args, x_index);
}
} // namespace
void AudioSpectrogram::set_window_size(const int64_t window_size) { void AudioSpectrogram::set_window_size(const int64_t window_size) {
(void)this->AddAttr(kWindowSize, MakeValue(window_size)); (void)this->AddAttr(kWindowSize, MakeValue(window_size));
} }
@ -102,15 +50,6 @@ void AudioSpectrogram::Init(const int64_t window_size, const int64_t stride, con
this->set_stride(stride); this->set_stride(stride);
this->set_mag_square(mag_square); this->set_mag_square(mag_square);
} }
AbstractBasePtr AudioSpectrogramInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name());
return std::make_shared<abstract::AbstractTensor>(AudioSpectrogramInferType(primitive, input_args),
AudioSpectrogramInferShape(primitive, input_args));
}
REGISTER_PRIMITIVE_C(kNameAudioSpectrogram, AudioSpectrogram); REGISTER_PRIMITIVE_C(kNameAudioSpectrogram, AudioSpectrogram);
} // namespace ops } // namespace ops
} // namespace mindspore } // namespace mindspore