fix code check
parent 16bb659a6b
commit ffe4393e95

@@ -24,6 +24,7 @@
 "mindspore/mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.cc" "useStlAlgorithm"
+"mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/cast_gpu_kernel.cc" "unknownMacro"
 "mindspore/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc" "nullPointerArithmeticRedundantCheck"
 "mindspore/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.cc" "containerOutOfBounds"

 # MindData
 "mindspore/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.cc" "useStlAlgorithm"

@@ -826,7 +826,7 @@ REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(J, prim::kPrimJ, InferImplJ, nullptr);
 REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Shard, prim::kPrimShard, InferImplShard, nullptr);
 REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs,
                                    InferImplBroadcastGradientArgs, nullptr);
-REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(MakeSiice, prim::kPrimMakeSlice, InferImplMakeSlice, nullptr);
+REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(MakeSlice, prim::kPrimMakeSlice, InferImplMakeSlice, nullptr);
 REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(SliceGetItem, prim::kPrimSliceGetItem, InferImplSliceGetItem, nullptr);
 } // namespace abstract
 } // namespace mindspore

@@ -84,9 +84,9 @@ class Resource : public ResourceBase {
     loop_size_ = size;
   }
   void set_is_load(bool flag) { is_load_ = flag; }
-  bool is_load() { return is_load_; }
-  bool vm_loop_flag() { return vm_loop_flag_; }
-  int64_t loop_size() { return loop_size_; }
+  bool is_load() const { return is_load_; }
+  bool vm_loop_flag() const { return vm_loop_flag_; }
+  int64_t loop_size() const { return loop_size_; }

   void set_layout_map(const LayoutMap &layout_map) { layout_map_ = layout_map; }
   const LayoutMap &get_layout_map() const { return layout_map_; }

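Note: the bulk of this commit const-qualifies read-only accessors like the three
above. A minimal sketch (not MindSpore code; names are illustrative) of why the
check wants this:

    class Resource {
     public:
      void set_is_load(bool flag) { is_load_ = flag; }
      bool is_load() const { return is_load_; }  // const: callable on const objects

     private:
      bool is_load_{false};
    };

    // Compiles only because is_load() is const-qualified.
    bool Dump(const Resource &res) { return res.is_load(); }
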
@@ -294,7 +294,7 @@ class AsyncInferTask {
   static AsyncInferTaskPtr MakeShared(const AsyncAbstractPtr &abstract, const std::string &thread = "") {
     std::string thread_id = thread;
     if (thread_id == "") {
-      thread_id = AnalysisSchedule::GetInstance().thread_id();
+      thread_id = AnalysisSchedule::thread_id();
     }
     MS_EXCEPTION_IF_NULL(abstract);
     auto ret = std::make_shared<AsyncInferTask>(thread_id, abstract);

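Note: thread_id() is now called as a static member, dropping the GetInstance()
hop. A hedged sketch of the general pattern (hypothetical names, not the real
AnalysisSchedule API):

    #include <string>

    class Scheduler {
     public:
      static Scheduler &GetInstance() {
        static Scheduler instance;  // Meyers singleton
        return instance;
      }
      // A member that touches no instance state can be static...
      static std::string thread_id() { return "0"; }
    };

    // ...so call sites shrink from Scheduler::GetInstance().thread_id()
    // to Scheduler::thread_id().
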
@@ -387,10 +387,10 @@ class EvaluatorCacheMgr {
   ~EvaluatorCacheMgr() = default;

   void Clear() { eval_result_cache_.clear(); }
-  const EvalResultCache &GetCache() { return eval_result_cache_; }
+  const EvalResultCache &GetCache() const { return eval_result_cache_; }
   EvalResultPtr GetValue(const AbstractBasePtrList &key) { return eval_result_cache_.get(key); }
   void SetValue(const AbstractBasePtrList &key, const EvalResultPtr &arg) { eval_result_cache_.set(key, arg); }
-  size_t GetSize() { return eval_result_cache_.size(); }
+  size_t GetSize() const { return eval_result_cache_.size(); }

  private:
   EvalResultCache eval_result_cache_;

@@ -616,7 +616,6 @@ class SideEffectFinder {
     if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple)) {
       if (tuple_indexes->empty()) {
         MS_LOG(EXCEPTION) << "Unexpected make_tuple: " << cnode->DebugString(2);
-        return {EffectInfo::kDetected, false, false, false};
       }
       // Pop out tuple index.
       auto top_index = tuple_indexes->top();

@@ -73,7 +73,7 @@ class Evaluator : public Base {
       return nullptr;
     }

-    auto is_abstract = std::any_of(args_spec_list.begin(), args_spec_list.end(), [](auto &arg) {
+    auto is_abstract = std::any_of(args_spec_list.begin(), args_spec_list.end(), [](auto &arg) -> bool {
       if (arg->BuildType()->type_id() == kObjectTypeUndeterminedType) {
         return true;
       }

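Note: the lambda gains an explicit "-> bool". When a lambda has several return
statements, spelling the return type out keeps every path converted to one type
and is friendlier to static checkers. A standalone sketch:

    #include <algorithm>
    #include <vector>

    int main() {
      std::vector<int> v{1, -2, 3};
      bool has_negative = std::any_of(v.begin(), v.end(), [](const int &x) -> bool {
        if (x < 0) {
          return true;
        }
        return false;
      });
      return has_negative ? 0 : 1;
    }
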
@@ -94,7 +94,6 @@ class Evaluator : public Base {

   EvaluatorCacheMgrPtr evaluator_cache_mgr() const { return evaluator_cache_mgr_; }
   EvaluatorAttrCachePtr attr_cache() const { return attr_cache_; }

-  const std::recursive_timed_mutex &eval_lock() const { return eval_lock_; }

  protected:

@@ -89,7 +89,7 @@ class UniformPrimEvaluator final : public TrivialPrimEvaluator {
       }
     }
   }
-  ~UniformPrimEvaluator() override = default;
+  ~UniformPrimEvaluator() override { impl_ = nullptr; };
   MS_DECLARE_PARENT(UniformPrimEvaluator, TrivialPrimEvaluator);

   EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;

@@ -122,7 +122,7 @@ class DoSignatureEvaluator final : public Evaluator {
   ~DoSignatureEvaluator() override = default;
   MS_DECLARE_PARENT(DoSignatureEvaluator, Evaluator);
   EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
-                    const AnfNodeConfigPtr &out_config = nullptr) override;
+                    const AnfNodeConfigPtr &out_config) override;

   EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
     MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";

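Note: this hunk and the two below drop "= nullptr" from Run() overrides. Default
arguments on virtual functions are bound to the static type at the call site, so
an override that restates (or changes) them invites surprises; keeping them only
on the base declaration avoids that. An illustrative sketch:

    #include <iostream>

    struct Base {
      virtual ~Base() = default;
      virtual void Run(int x = 1) { std::cout << "Base " << x << "\n"; }
    };

    struct Derived : Base {
      void Run(int x) override { std::cout << "Derived " << x << "\n"; }  // no default restated
    };

    int main() {
      Derived d;
      Base *b = &d;
      b->Run();  // prints "Derived 1": body from Derived, default argument from Base
      return 0;
    }
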
@@ -138,7 +138,7 @@ class UnpackGraphEvaluator final : public Evaluator {
   ~UnpackGraphEvaluator() override = default;
   MS_DECLARE_PARENT(UnpackGraphEvaluator, Evaluator);
   EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
-                    const AnfNodeConfigPtr &out_config = nullptr) override;
+                    const AnfNodeConfigPtr &out_config) override;

   EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
     MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";

@@ -155,7 +155,7 @@ class MixedPrecisionCastEvaluator final : public Evaluator {
   ~MixedPrecisionCastEvaluator() override = default;
   MS_DECLARE_PARENT(MixedPrecisionCastEvaluator, Evaluator);
   EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
-                    const AnfNodeConfigPtr &out_config = nullptr) override;
+                    const AnfNodeConfigPtr &out_config) override;

   EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
     MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";

@@ -27,7 +27,6 @@
 namespace mindspore {
 namespace abstract {
 class StackFrame;
 using StackFramePtr = std::shared_ptr<StackFrame>;
-using EvaluatorWeakPtr = std::weak_ptr<Evaluator>;
 using BaseFuncGraphEvaluatorPtr = std::shared_ptr<BaseFuncGraphEvaluator>;

@@ -65,7 +64,7 @@ class StackFrame final : public Base {
   // Return back from branch func graph.
   void Back(const AnalysisEnginePtr &engine, const StackFramePtr &last_stack_frame, const EvalResultPtr &eval_result);

-  bool Done() { return done_; }
+  bool Done() const { return done_; }

   AnfNodePtr &CurrentNode() {
     if (slot_index_ >= node_slots_.size()) {

@@ -92,7 +91,7 @@ class StackFrame final : public Base {
   AnalysisContextPtr current_context() const { return current_context_; }
   AnalysisContextPtr parent_context() const { return parent_context_; }

-  const AbstractBasePtrList &args_abs_list() { return args_abs_list_; }
+  const AbstractBasePtrList &args_abs_list() const { return args_abs_list_; }
   void set_args_abs_list(const AbstractBasePtrList &&args_abs_list) { args_abs_list_ = args_abs_list; }

   std::string ToString() const override {

@@ -63,7 +63,7 @@ using AttrValueMapPtr = std::shared_ptr<AttrValueMap>;
 // the class to save evaluated result: abstract value and modified attribute
 class EvalResult : public Base {
  public:
-  EvalResult(AbstractBasePtr abs, AttrValueMapPtr attr) : abstract_(abs), attribute_(attr) {}
+  EvalResult(const AbstractBasePtr &abs, const AttrValueMapPtr &attr) : abstract_(abs), attribute_(attr) {}
   ~EvalResult() override = default;
   MS_DECLARE_PARENT(EvalResult, Base);
   AbstractBasePtr abstract() { return abstract_; }

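Note: the constructor now takes its shared_ptr arguments by const reference,
avoiding an atomic refcount bump per parameter on every call. A minimal sketch
of the pattern (simplified types):

    #include <memory>

    struct Abstract {};
    using AbstractPtr = std::shared_ptr<Abstract>;

    class EvalResult {
     public:
      // One copy, directly into the member; the parameter itself is not copied.
      explicit EvalResult(const AbstractPtr &abs) : abstract_(abs) {}

     private:
      AbstractPtr abstract_;
    };

    int main() {
      AbstractPtr a = std::make_shared<Abstract>();
      EvalResult result(a);
      (void)result;
      return 0;
    }
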
@@ -261,7 +261,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
     forward_count_ = 0;
     enable_recursive_eval_ = (common::GetEnv("MS_DEV_RECURSIVE_EVAL") == "1");
   }
-  ~AnalysisEngine() = default;
+  virtual ~AnalysisEngine() = default;

   // func_graph: The func_graph to analyze.
   // args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase.

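Note: AnalysisEngine gains a virtual destructor. Deleting a derived object
through a base pointer is undefined behavior unless the base destructor is
virtual. A minimal sketch (illustrative types):

    #include <memory>

    struct Engine {
      virtual ~Engine() = default;  // virtual: safe polymorphic deletion
    };

    struct TracedEngine : Engine {
      ~TracedEngine() override = default;
    };

    int main() {
      std::unique_ptr<Engine> e = std::make_unique<TracedEngine>();
      return 0;  // ~TracedEngine runs, then ~Engine
    }
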
@@ -71,7 +71,7 @@ class PrimitivePy : public Primitive {
   py::dict RunInfer(const py::tuple &args);
   void RunCheck(const py::tuple &args);
   py::object RunInferValue(const py::tuple &args);
-  bool HasPyObj() { return python_obj_.operator bool(); }
+  bool HasPyObj() const { return python_obj_.operator bool(); }
   PrimitivePtr Clone() override;
   PrimitivePyAdapterPtr adapter() const { return adapter_; }
   void set_bprop_cls_name(const std::string &name) { bprop_cls_name_ = name; }

@@ -102,7 +102,7 @@ class PrimitivePyAdapter {
   void set_hook(const py::function &hook);
   void set_instance_name(const std::string &s);
   void set_attached_primitive(const PrimitivePyPtr &prim);
-  PrimitivePyPtr attached_primitive() { return attached_primitive_.lock(); }
+  PrimitivePyPtr attached_primitive() const { return attached_primitive_.lock(); }
   std::string name() const { return name_; }
   void set_name(const std::string &name) { name_ = name; }
   const bool parse_info_ = true;

@@ -16,12 +16,6 @@

 #include "pybind_api/ir/tensor_py.h"

-#include <vector>
-#include <sstream>
-#include <string>
-#include <utility>
-#include <complex>
-
 #include "pybind_api/api_register.h"
 #include "abstract/abstract_value.h"
 #include "utils/shape_utils.h"

@@ -125,6 +119,8 @@ static std::string GetPyTypeFormat(TypeId data_type) {
       return py::format_descriptor<std::complex<float>>::format();
     case TypeId::kNumberTypeComplex128:
       return py::format_descriptor<std::complex<double>>::format();
+    case TypeId::kMetaTypeType:
+    case TypeId::kMetaTypeEllipsis:
     default:
       MS_LOG(WARNING) << "Unsupported DataType " << data_type << ".";
       return "";

@@ -182,7 +178,7 @@ class TensorDataNumpy : public TensorData {
   }

  private:
-  void *buffer_data() { return buffer_.ptr; }
+  void *buffer_data() const { return buffer_.ptr; }

   // The internal buffer.
   py::buffer_info buffer_;

@@ -15,8 +15,6 @@
 */
 #include "pybind_api/random_normal/random_cpu_kernel.h"

-#include <thread>
-#include <memory>
 #include "runtime/device/cpu/cpu_device_address.h"
 #include "ir/tensor.h"

@@ -1138,12 +1138,11 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(con
   parser->Parse(control_nodes_, graphs, device_contexts, root_graph, func_graph_to_kernel_graphs);

   runtime::KernelMapPosition outputs_order;
-  size_t outputs_num = 0;
   const auto &root_output =
     AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0, false, {prim::kPrimTupleGetItem}).first;
   size_t position = 0;
   auto outputs = AnfAlgo::GetAllOutputWithIndex(root_output);
-  outputs_num = outputs.size();
+  size_t outputs_num = outputs.size();
   for (const auto &output : outputs) {
     if (outputs_order.count(output) == 0) {
       outputs_order[output] = {position++};

@@ -134,7 +134,7 @@ void FinalVM::Popsp() {
   int64_t sp = retsp_.top();
   MS_LOG(DEBUG) << "Current sp:" << sp_ << ", before sp:" << sp << ", " << sp_ - sp;
   if (sp_ >= sp) {
-    Pop(sp_ - sp + 1);
+    Pop((sp_ - sp) + 1);
     retsp_.pop();
   } else {
     MS_LOG(EXCEPTION) << "Stack point sp_:" << sp << " must be bigger than sp:" << sp_;

@@ -268,14 +268,12 @@ class MS_CORE_API PartialAbstractClosure final : public AbstractFuncAtom {
   /// \brief Get the pre-provided arguments.
   ///
   /// \return The pre-provided arguments.
-  const AbstractBasePtrList &args() { return args_spec_list_; }
-
-  ValuePtr RealBuildValue() const override { return fn_->BuildValue(); }
+  const AbstractBasePtrList &args() const { return args_spec_list_; }

   /// \brief Get the CNode this PartialAbstractClosure evaluated from.
   ///
   /// \return The CNode this PartialAbstractClosure evaluated from.
-  AnfNodePtr node() { return node_.lock(); }
+  AnfNodePtr node() const { return node_.lock(); }

   /// \brief Set the CNode this PartialAbstractClosure evaluated from.
   ///

@@ -292,6 +290,9 @@ class MS_CORE_API PartialAbstractClosure final : public AbstractFuncAtom {

   std::string ToString() const override;

+ protected:
+  ValuePtr RealBuildValue() const override { return fn_->BuildValue(); }
+
  private:
   AbstractFuncAtomPtr fn_;
   AbstractBasePtrList args_spec_list_;

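Note: RealBuildValue() moves from the public section (previous hunk) into a new
protected section: it is an internal hook, reached through the public
BuildValue() entry point. A sketch of that shape (simplified, hypothetical
types, not the real MindSpore hierarchy):

    #include <memory>

    struct Value {};
    using ValuePtr = std::shared_ptr<Value>;

    class AbstractFunc {
     public:
      virtual ~AbstractFunc() = default;
      ValuePtr BuildValue() const { return RealBuildValue(); }  // public entry point

     protected:
      virtual ValuePtr RealBuildValue() const { return nullptr; }  // overridable hook
    };

    class PartialClosure final : public AbstractFunc {
     protected:
      ValuePtr RealBuildValue() const override { return std::make_shared<Value>(); }
    };

    int main() {
      PartialClosure pc;
      return pc.BuildValue() != nullptr ? 0 : 1;
    }
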
@@ -147,7 +147,7 @@ class MS_CORE_API AbstractBase : public Base {
   /// \brief Set the function, which prints the debug info.
   ///
   /// \param[in] trace_node_provider The function.
-  static void set_trace_node_provider(TraceNodeProvider trace_node_provider) {
+  static void set_trace_node_provider(const TraceNodeProvider &trace_node_provider) {
     trace_node_provider_ = trace_node_provider;
   }

@@ -18,11 +18,6 @@

 #include "abstract/dshape.h"

-#include <exception>
-#include <iostream>
-
-#include "utils/log_adapter.h"
-
 namespace mindspore {
 namespace abstract {
 namespace {

@@ -163,17 +163,17 @@ class MS_CORE_API Shape final : public BaseShape {
   /// \brief Get shape dimensions.
   ///
   /// \return Shape dimensions.
-  const ShapeVector &shape() { return shape_; }
+  const ShapeVector &shape() const { return shape_; }

   /// \brief Get minimum shape dimensions.
   ///
   /// \return Minimum shape dimensions.
-  const ShapeVector &min_shape() { return min_shape_; }
+  const ShapeVector &min_shape() const { return min_shape_; }

   /// \brief Get maximum shape dimensions.
   ///
   /// \return Maximum shape dimensions.
-  const ShapeVector &max_shape() { return max_shape_; }
+  const ShapeVector &max_shape() const { return max_shape_; }

   bool IsDynamic() const override {
     return std::any_of(shape_.begin(), shape_.end(), [](int64_t s) { return s < 0; });

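Note: alongside the const getters, IsDynamic() shows the convention this class
relies on: dynamic dimensions are encoded as negative values. A standalone
sketch of that check:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    using ShapeVector = std::vector<int64_t>;

    bool IsDynamic(const ShapeVector &shape) {
      return std::any_of(shape.begin(), shape.end(), [](int64_t s) { return s < 0; });
    }

    int main() {
      return (!IsDynamic({2, 3}) && IsDynamic({-1, 3})) ? 0 : 1;
    }
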
@@ -22,12 +22,6 @@

 namespace mindspore {
 namespace abstract {
-const size_t stride_num_element = 2;
-const size_t stride_start_idx = 2;
-const size_t dilation_num_element = 2;
-const size_t dilation_start_idx = 2;
-const size_t padding_num_element = 4;
-const size_t padding_start_idx = 0;
 int64_t GetAndCheckFormat(const ValuePtr &value) {
   int64_t data_format;
   bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format);

@@ -81,7 +75,7 @@ AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &
   auto pad_mode_ptr = primitive->GetAttr("pad_mode");
   if (pad_mode_ptr != nullptr) {
     int64_t pad_mode;
-    const size_t middle = 2;
+    const int64_t middle = 2;
     CheckAndConvertUtils::GetPadModEnumValue(pad_mode_ptr, &pad_mode, true);
     if (pad_mode == static_cast<int64_t>(PadMode::VALID)) {
       padding = 0;

@@ -97,7 +91,7 @@ AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &
       MS_LOG(EXCEPTION) << "Unsupported pooling mode: " << mode << ".";
     }
   }
-  const size_t twice = 2;
+  const int64_t twice = 2;
   int64_t h_out = (((h_input + twice * padding - (window - 1)) - 1) / stride) + 1;
   int64_t w_out = (((w_input + twice * padding - (window - 1)) - 1) / stride) + 1;
   ShapeVector shape_out = {input_shape->shape()[0], input_shape->shape()[1], h_out, w_out};

@@ -202,50 +196,6 @@ AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const Primit
   return std::make_shared<AbstractTuple>(rets);
 }

-void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pad_list, const int64_t x_h,
-                       const int64_t x_w, const std::vector<int64_t> &kernel, const std::vector<int64_t> &stride,
-                       const std::vector<int64_t> &dilation, const int64_t &pad_mode,
-                       const std::vector<int64_t> &padding) {
-  const size_t middle = 2;
-  const size_t second_index = 2;
-  const size_t third_index = 3;
-  if (pad_mode == static_cast<int64_t>(PadMode::VALID)) {
-    output_hw->push_back(static_cast<int64_t>(std::ceil(((x_h * 1.0) - dilation[0] * (kernel[0] - 1)) / stride[0])));
-    output_hw->push_back(static_cast<int64_t>(std::ceil(((x_w * 1.0) - dilation[1] * (kernel[1] - 1)) / stride[1])));
-    const size_t nhwc = 4;
-    (void)pad_list->insert(pad_list->begin(), nhwc, 0);
-  } else if (pad_mode == static_cast<int64_t>(PadMode::SAME)) {
-    output_hw->push_back(static_cast<int64_t>(std::ceil((x_h * 1.0) / stride[0])));
-    output_hw->push_back(static_cast<int64_t>(std::ceil((x_w * 1.0) / stride[1])));
-    int64_t pad_needed_h = (output_hw->at(0) - 1) * stride[0] + dilation[0] * (kernel[0] - 1) + 1 - x_h;
-    pad_needed_h = std::max((int64_t)0, pad_needed_h);
-    pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_h / middle)));
-    pad_list->push_back(pad_needed_h - pad_list->at(0));
-    int64_t pad_needed_w = (output_hw->at(1) - 1) * stride[1] + dilation[1] * (kernel[1] - 1) + 1 - x_w;
-    pad_needed_w = std::max((int64_t)0, pad_needed_w);
-    pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_w / middle)));
-    pad_list->push_back(pad_needed_w - pad_list->at(middle));
-  } else if (pad_mode == static_cast<int64_t>(PadMode::PAD)) {
-    (void)pad_list->insert(pad_list->begin(), padding.begin(), padding.end());
-    output_hw->push_back(static_cast<int64_t>(std::floor(
-      1 + (((x_h * 1.0) + pad_list->at(0) + pad_list->at(1) - kernel[0]) - (kernel[0] - 1) * (dilation[0] - 1)) /
-            stride[0])));
-    output_hw->push_back(static_cast<int64_t>(
-      std::floor(1 + (((x_w * 1.0) + pad_list->at(second_index) + pad_list->at(third_index) - kernel[1]) -
-                      (kernel[1] - 1) * (dilation[1] - 1)) /
-                       stride[1])));
-  }
-}
-
-void CheckShape(const std::string &op_name, const ShapeVector &w_shape, const AbstractTensorPtr &input_w) {
-  ShapeVector w_min_shape = input_w->shape()->min_shape();
-  ShapeVector w_max_shape = input_w->shape()->max_shape();
-  CheckMinMaxShape(w_shape, &w_min_shape, &w_max_shape);
-  CheckShapeAnyAndPositive(op_name + " w_shape", w_shape);
-  CheckShapeAllPositive(op_name + " w_min_shape", w_min_shape);
-  CheckShapeAllPositive(op_name + " w_max_shape", w_max_shape);
-}
-
 AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list) {
   // Inputs: at least one tensor(y_backprop)

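Note: for reference, the deleted SAME branch computed out = ceil(x / stride) and
pad_needed = (out - 1) * stride + dilation * (kernel - 1) + 1 - x, clamped at 0
and split as floor(pad_needed / 2) on one side with the remainder on the other.
For example, x_h = 5, kernel = 3, stride = 2, dilation = 1 gives out = 3,
pad_needed = 2, and pads (1, 1).
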
@@ -18,9 +18,6 @@

 #include "abstract/utils.h"

-#include <string>
-#include <sstream>
-#include <memory>
 #include "utils/ms_context.h"
 #include "utils/symbolic.h"
 #include "abstract/param_validator.h"

@@ -59,7 +59,7 @@ const std::string GetSubModuleName(SubModuleId module_id) { return sub_module_na

 // export GetTimeString for all sub modules
 std::string GetTimeString() {
-#define BUFLEN 80
+  constexpr auto BUFLEN = 80;
   char buf[BUFLEN];
   (void)memset(buf, '\0', BUFLEN);
 #if defined(_WIN32) || defined(_WIN64)

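Note: the function-local macro becomes a constexpr constant, which is scoped and
typed but still usable as an array bound. A standalone sketch (the
time-formatting body is illustrative, not the original code):

    #include <cstring>
    #include <ctime>
    #include <string>

    std::string GetTimeString() {
      constexpr auto BUFLEN = 80;        // scoped to this function, unlike #define
      char buf[BUFLEN];
      (void)memset(buf, '\0', BUFLEN);
      std::time_t t = std::time(nullptr);
      (void)std::strftime(buf, BUFLEN, "%Y-%m-%d %H:%M:%S", std::localtime(&t));
      return std::string(buf);
    }

    int main() { return GetTimeString().empty() ? 1 : 0; }
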
@@ -26,58 +26,6 @@

 namespace mindspore {
 namespace ops {
-namespace {
-int64_t Log2Ceil(int64_t length) {
-  if (length == 0) {
-    return -1;
-  }
-  int64_t floor = 0;
-  for (int64_t i = 4; i >= 0; --i) {
-    const int64_t shift = static_cast<int64_t>(1UL << static_cast<unsigned>(i));
-    int64_t tmp = SizeToLong(static_cast<uint64_t>(length) >> static_cast<uint64_t>(shift));
-    if (tmp != 0) {
-      length = tmp;
-      floor += shift;
-    }
-  }
-  return length == (length & ~(unsigned int)(length - 1)) ? floor : floor + 1;
-}
-
-int64_t GetFftLength(int64_t length) {
-  int64_t shift = Log2Ceil(length);
-  return SizeToLong(1UL << LongToSize(shift));
-}
-
-abstract::ShapePtr AudioSpectrogramInferShape(const PrimitivePtr &primitive,
-                                              const std::vector<AbstractBasePtr> &input_args) {
-  auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
-  if (input_shape.size() != 2) {
-    MS_LOG(ERROR) << "input shape is error, which need to be 2 dimensions";
-  }
-  auto window_size = GetValue<int64_t>(primitive->GetAttr(kWindowSize));
-  if (window_size < 2) {
-    MS_LOG(ERROR) << "window size is too short, now is " << window_size;
-  }
-  auto stride_size = GetValue<int64_t>(primitive->GetAttr(kStride));
-  if (stride_size < 1) {
-    MS_LOG(ERROR) << "stride must be positive, now is " << stride_size;
-  }
-  std::vector<int64_t> infer_shape;
-  infer_shape.push_back(input_shape[1]);
-  int64_t sample_sub_window = input_shape[0] - window_size;
-  infer_shape.push_back(sample_sub_window < 0 ? 0 : 1 + sample_sub_window / stride_size);
-  int64_t fft_length = GetFftLength(window_size);
-  infer_shape.push_back(fft_length / 2 + 1);
-  MS_LOG(ERROR) << infer_shape;
-  return std::make_shared<abstract::Shape>(infer_shape);
-}
-
-TypePtr AudioSpectrogramInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
-  const size_t x_index = 0;
-  return CheckAndConvertUtils::GetTensorInputType(prim->name(), input_args, x_index);
-}
-} // namespace
-
 void AudioSpectrogram::set_window_size(const int64_t window_size) {
   (void)this->AddAttr(kWindowSize, MakeValue(window_size));
 }

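Note: the deleted Log2Ceil/GetFftLength pair rounded the analysis window up to
the next power of two. Ignoring the zero-length guard, it is equivalent to:

    #include <cstdint>

    int64_t NextPowerOfTwo(int64_t n) {
      int64_t p = 1;
      while (p < n) {
        p <<= 1;
      }
      return p;
    }

    int main() {
      // A 400-sample window maps to a 512-point FFT, so the spectrogram's last
      // dimension is 512 / 2 + 1 = 257, matching the removed shape inference.
      return NextPowerOfTwo(400) == 512 ? 0 : 1;
    }
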
@@ -102,15 +50,6 @@ void AudioSpectrogram::Init(const int64_t window_size, const int64_t stride, con
   this->set_stride(stride);
   this->set_mag_square(mag_square);
 }
-
-AbstractBasePtr AudioSpectrogramInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                      const std::vector<AbstractBasePtr> &input_args) {
-  MS_EXCEPTION_IF_NULL(primitive);
-  const int64_t input_num = 1;
-  CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name());
-  return std::make_shared<abstract::AbstractTensor>(AudioSpectrogramInferType(primitive, input_args),
-                                                    AudioSpectrogramInferShape(primitive, input_args));
-}
 REGISTER_PRIMITIVE_C(kNameAudioSpectrogram, AudioSpectrogram);
 } // namespace ops
 } // namespace mindspore