Code self-check

Static check clean
This commit is contained in:
chenfei 2022-08-24 17:01:15 +08:00
parent cb4ef260a7
commit cd3d6f3da0
8 changed files with 44 additions and 21 deletions

View File

@@ -186,7 +186,7 @@ FuncGraphPtr ListReverse::GenerateFuncGraph(const abstract::AbstractBasePtrList
std::vector<AnfNodePtr> elems;
elems.push_back(NewValueNode(prim::kPrimMakeList));
for (int64_t i = arg_length - 1; i >= 0; --i) {
elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg_node, NewValueNode(SizeToLong(i))}));
elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg_node, NewValueNode(i)}));
}
ret->set_output(ret->NewCNode(elems));
@@ -232,7 +232,9 @@ bool ListCount::ComparesTwoValues(const ValuePtr &count_value, const ValuePtr &l
if (count_value->isa<AnyValue>() || list_value->isa<AnyValue>()) {
MS_EXCEPTION(NotSupportError) << "The list count not support " << count_value->type_name() << " type now.";
} else if (count_value->isa<tensor::Tensor>()) {
return count_value->cast_ptr<tensor::Tensor>()->ValueEqual(*list_value->cast_ptr<tensor::Tensor>());
auto list_tensor_value = list_value->cast_ptr<tensor::Tensor>();
MS_EXCEPTION_IF_NULL(list_tensor_value);
return count_value->cast_ptr<tensor::Tensor>()->ValueEqual(*list_tensor_value);
} else {
return *count_value == *list_value;
}

View File

@@ -93,7 +93,9 @@ class ListExtend : public MetaFuncGraph {
return os;
}
friend bool operator==(const ListExtend &lhs, const ListExtend &rhs) { return lhs.name_ == rhs.name_; }
void AddNodeToElems(const AbstractBasePtr &arg, const FuncGraphPtr &ret, std::vector<AnfNodePtr> *elems);
private:
static void AddNodeToElems(const AbstractBasePtr &arg, const FuncGraphPtr &ret, std::vector<AnfNodePtr> *elems);
};
using ListExtendPtr = std::shared_ptr<ListExtend>;

View File

@@ -264,6 +264,7 @@ bool CheckPythonIsInstance(const py::object &x, const AbstractBasePtr &cmp, cons
}
bool CheckIsInstanceForFunc(const py::object &x_py_obj, const AbstractBasePtr &cmp, const py::module &mod) {
MS_EXCEPTION_IF_NULL(cmp);
if (cmp->isa<abstract::AbstractTuple>()) {
const auto &cmp_tuple_elements = cmp->cast<abstract::AbstractTuplePtr>()->elements();
return std::any_of(
@@ -283,6 +284,7 @@ bool CheckIsInstanceForFunc(const py::object &x_py_obj, const AbstractBasePtr &c
}
bool CheckIsInstanceForSparse(const AbstractBasePtr &cmp, const std::string &target) {
MS_EXCEPTION_IF_NULL(cmp);
if (!cmp->isa<abstract::AbstractTuple>()) {
return cmp->ToString() == target;
}
@@ -292,6 +294,7 @@ bool CheckIsInstanceForSparse(const AbstractBasePtr &cmp, const std::string &tar
}
py::object GetPrimitivePyObj(const abstract::PrimitiveAbstractClosurePtr &prim_abs) {
MS_EXCEPTION_IF_NULL(prim_abs);
auto prim = prim_abs->prim();
MS_EXCEPTION_IF_NULL(prim);
auto prim_signature = prim->cast<prim::DoSignaturePrimitivePtr>();
@@ -303,12 +306,14 @@ py::object GetPrimitivePyObj(const abstract::PrimitiveAbstractClosurePtr &prim_a
}
py::object GetMsClassPyObj(const abstract::PartialAbstractClosurePtr &ms_class_abs) {
MS_EXCEPTION_IF_NULL(ms_class_abs);
const auto &ms_class_args = ms_class_abs->args();
if (ms_class_args.size() != 1) {
MS_LOG(EXCEPTION) << "When the first input to IsInstance is PartialAbstractClosure, its args size should be 1 but "
<< "got: " << ms_class_args.size() << ".";
}
auto first_arg = ms_class_args[0];
MS_EXCEPTION_IF_NULL(first_arg);
auto arg_type = first_arg->BuildType();
auto arg_type_id = arg_type->type_id();
if (arg_type_id != kObjectTypeClass) {
@@ -321,6 +326,7 @@ py::object GetMsClassPyObj(const abstract::PartialAbstractClosurePtr &ms_class_a
}
bool CheckCmpValid(const AbstractBasePtr &cmp) {
MS_EXCEPTION_IF_NULL(cmp);
if (cmp->isa<abstract::AbstractSequence>()) {
if (!cmp->isa<abstract::AbstractTuple>()) {
return false;
@@ -352,7 +358,7 @@ bool CheckCmpValid(const AbstractBasePtr &cmp) {
}
return cmp_type_id == kMetaTypeTypeType;
}
return std::find(kSparsePrimStr.begin(), kSparsePrimStr.end(), cmp->ToString()) != kSparsePrimStr.end();
return std::find(kSparsePrimStr.cbegin(), kSparsePrimStr.cend(), cmp->ToString()) != kSparsePrimStr.cend();
}
AbstractBasePtr InferImplIsInstance(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
@@ -374,6 +380,7 @@ AbstractBasePtr InferImplIsInstance(const AnalysisEnginePtr &, const PrimitivePt
}
// x is Cell object.
MS_EXCEPTION_IF_NULL(x);
if (x->isa<abstract::FuncGraphAbstractClosure>()) {
auto x_fg = x->cast<abstract::FuncGraphAbstractClosurePtr>()->func_graph();
MS_EXCEPTION_IF_NULL(x_fg);
@@ -611,12 +618,14 @@ AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr
MS_LOG(INFO) << "The shape of dividend:" << shape_x->ToString() << ", the shape of divisor:" << div_shp->ToString();
auto div_shp_value = div_shp->BuildValue();
MS_EXCEPTION_IF_NULL(div_shp_value);
if (div_shp_value->isa<AnyValue>()) {
MS_LOG(EXCEPTION) << "The 'tuple_div' operator shape's data field can't be anything, but got "
<< args_spec_list[0]->ToString() << ".";
}
auto shape_x_value = shape_x->BuildValue();
MS_EXCEPTION_IF_NULL(shape_x_value);
if (shape_x_value->isa<AnyValue>()) {
MS_LOG(EXCEPTION) << "The 'tuple_div' operator shape's data field can't be anything, but got "
<< args_spec_list[1]->ToString() << ".";
@@ -626,12 +635,16 @@ AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr
MS_LOG(EXCEPTION) << "The size of inputs of 'tuple_div' operator must be the same, but the size of divisor tuple is"
<< " " << div_shp->size() << ", the size of dividend tuple is " << shape_x->size() << ".";
}
auto shape_x_data = shape_x_value->cast<ValueTuplePtr>()->value();
auto div_shape_data = div_shp_value->cast<ValueTuplePtr>()->value();
auto shape_x_tuple_value = shape_x_value->cast<ValueTuplePtr>();
auto div_shape_tuple_value = div_shp_value->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(shape_x_tuple_value);
MS_EXCEPTION_IF_NULL(div_shape_tuple_value);
auto shape_x_data = shape_x_tuple_value->value();
auto div_shape_data = div_shape_tuple_value->value();
AbstractBasePtrList values;
for (size_t i = 0; i < div_shape_data.size(); i++) {
MS_EXCEPTION_IF_NULL(div_shape_data[i]);
if (div_shape_data[i]->cast<Int64ImmPtr>() == nullptr) {
auto value_type = div_shape_data[i]->type();
std::string str_type;
@@ -643,8 +656,8 @@ AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr
MS_LOG(EXCEPTION) << "The data type of inputs of 'tuple_div' operator should be an int64 number, but got a "
<< str_type << " number " << div_shape_data[i]->ToString() << ".";
}
int64_t shapex_value = GetValue<int64_t>(shape_x_data[i]);
int64_t div_value = GetValue<int64_t>(div_shape_data[i]);
auto shapex_value = GetValue<int64_t>(shape_x_data[i]);
auto div_value = GetValue<int64_t>(div_shape_data[i]);
MS_LOG(DEBUG) << "div_shp_shape data shapex_value :" << shapex_value << " div_value: " << div_value;
if (div_value == 0) {
MS_LOG(EXCEPTION) << "The divisor value should not be 0!";

View File

@@ -794,6 +794,10 @@ AnfNodePtr KPynativeCellImpl::BuildKNodeForCNodeInput(const PynativeAdjointPtr &
if (input_index < 1) {
MS_EXCEPTION(ValueError) << "The input_index is smaller than 1.";
}
if (input_index > cnode_adjoint->op_args().size()) {
MS_EXCEPTION(ValueError) << "The input_index: " << input_index
<< " out of range:" << cnode_adjoint->op_args().size();
}
return NewValueNode(cnode_adjoint->op_args()[input_index - 1]);
}
}

View File

@@ -212,7 +212,7 @@ class ChoicePartialEliminater : public AnfVisitor {
}
private:
std::vector<AnfNodePtr> ArgsUnion(const std::vector<AnfNodePtrList> args_list) {
static std::vector<AnfNodePtr> ArgsUnion(const std::vector<AnfNodePtrList> args_list) {
std::set<AnfNodePtr> no_monad_args;
std::set<AnfNodePtr> monad_args;
for (const auto &args : args_list) {
@@ -230,8 +230,8 @@ class ChoicePartialEliminater : public AnfVisitor {
return union_args;
}
HashMap<FuncGraphPtr, HashMap<AnfNodePtr, size_t>> GenOldArgsIndexes(const AnfNodePtrList &fg_list,
const std::vector<AnfNodePtrList> &args_list) {
static HashMap<FuncGraphPtr, HashMap<AnfNodePtr, size_t>> GenOldArgsIndexes(
const AnfNodePtrList &fg_list, const std::vector<AnfNodePtrList> &args_list) {
HashMap<FuncGraphPtr, HashMap<AnfNodePtr, size_t>> old_args_indexes;
for (size_t i = 0; i < fg_list.size(); ++i) {
const auto func_graph = GetValueNode<FuncGraphPtr>(fg_list[i]);
@@ -239,16 +239,16 @@ class ChoicePartialEliminater : public AnfVisitor {
const auto &args = args_list[i];
HashMap<AnfNodePtr, size_t> args_indexes;
size_t arg_index = 0;
std::for_each(args.cbegin(), args.cend(), [&args_indexes, &arg_index](const AnfNodePtr &arg) {
for (const auto &arg : args) {
(void)args_indexes.emplace(arg, arg_index++);
});
}
old_args_indexes[func_graph] = args_indexes;
}
return old_args_indexes;
}
AnfNodePtr GetParameterByArg(const HashMap<FuncGraphPtr, HashMap<AnfNodePtr, size_t>> &all_old_args_index_map,
const AnfNodePtr &arg) {
static AnfNodePtr GetParameterByArg(const HashMap<FuncGraphPtr, HashMap<AnfNodePtr, size_t>> &all_old_args_index_map,
const AnfNodePtr &arg) {
MS_LOG(DEBUG) << "Get parameter by arg:" << arg->DebugString();
for (const auto &[fg, old_args_index] : all_old_args_index_map) {
auto it = old_args_index.find(arg);
@@ -352,7 +352,7 @@ class SwitchPartialEliminater : public ChoicePartialEliminater {
// Create switch call.
TraceGuard guard2(std::make_shared<TraceCopy>(switch_call->debug_info()));
AnfNodePtrList switch_call_inputs{new_switch_cnode};
switch_call_inputs.insert(switch_call_inputs.end(), new_args.begin(), new_args.end());
(void)switch_call_inputs.insert(switch_call_inputs.end(), new_args.begin(), new_args.end());
const auto new_call_node = switch_call->func_graph()->NewCNode(std::move(switch_call_inputs));
new_call_node->set_abstract(switch_call->abstract());
return new_call_node;

View File

@@ -141,8 +141,8 @@ inline std::tuple<size_t, size_t, size_t> CalPosition(const OffsetIndex &offset_
inline InputXIndex CalInputXIndex(const OffsetIndex &offset_index, const DeformableOffsetGradDims &dims) {
InputXIndex input_x_index;
input_x_index.i = -1.0 * SizeToFloat(dims.pad_top);
input_x_index.j = -1.0 * SizeToFloat(dims.pad_left);
input_x_index.i = -1.0f * SizeToFloat(dims.pad_top);
input_x_index.j = -1.0f * SizeToFloat(dims.pad_left);
input_x_index.i += SizeToFloat(offset_index.offset_i * dims.stride_h + offset_index.kernel_i * dims.dilation_h);
input_x_index.j += SizeToFloat(offset_index.offset_j * dims.stride_w + offset_index.kernel_j * dims.dilation_w);
return input_x_index;

View File

@@ -45,13 +45,13 @@ bool Compute(const ComputeParams<T, S> *params, const size_t start, const size_t
std::vector<size_t> local_indices;
for (size_t j = 0; j < params->indices_unit_rank_; ++j) {
auto index = indices[i * params->indices_unit_rank_ + j];
(void)local_indices.emplace_back(IntToSize(index));
(void)local_indices.emplace_back(LongToSize(index));
if (index < 0) {
MS_LOG(ERROR) << "For '" << kKernelName
<< "', each element in 'indices' must be greater than or equal to 0, but got " << index;
return false;
}
offset += IntToSize(index) * out_strides->at(j) * params->unit_size_;
offset += LongToSize(index) * out_strides->at(j) * params->unit_size_;
}
if (offset * sizeof(T) > params->x_mem_size_) {
MS_LOG(ERROR) << "For '" << kKernelName

View File

@@ -56,6 +56,7 @@ Primitive::Primitive(const Primitive &prim)
prim_type_(prim.prim_type_),
record_evaluate_add_attr_(false),
is_const_prim_(false),
const_input_indexes_(prim.const_input_indexes_),
id_(prim.id_) {}
Primitive &Primitive::operator=(const Primitive &other) {
@@ -72,6 +73,7 @@ Primitive &Primitive::operator=(const Primitive &other) {
record_evaluate_add_attr_ = false;
is_const_prim_ = false;
id_ = other.id_;
const_input_indexes_ = other.const_input_indexes_;
return *this;
}