forked from mindspore-Ecosystem/mindspore
parent
cb4ef260a7
commit
cd3d6f3da0
|
@ -186,7 +186,7 @@ FuncGraphPtr ListReverse::GenerateFuncGraph(const abstract::AbstractBasePtrList
|
|||
std::vector<AnfNodePtr> elems;
|
||||
elems.push_back(NewValueNode(prim::kPrimMakeList));
|
||||
for (int64_t i = arg_length - 1; i >= 0; --i) {
|
||||
elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg_node, NewValueNode(SizeToLong(i))}));
|
||||
elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg_node, NewValueNode(i)}));
|
||||
}
|
||||
|
||||
ret->set_output(ret->NewCNode(elems));
|
||||
|
@ -232,7 +232,9 @@ bool ListCount::ComparesTwoValues(const ValuePtr &count_value, const ValuePtr &l
|
|||
if (count_value->isa<AnyValue>() || list_value->isa<AnyValue>()) {
|
||||
MS_EXCEPTION(NotSupportError) << "The list count not support " << count_value->type_name() << " type now.";
|
||||
} else if (count_value->isa<tensor::Tensor>()) {
|
||||
return count_value->cast_ptr<tensor::Tensor>()->ValueEqual(*list_value->cast_ptr<tensor::Tensor>());
|
||||
auto list_tensor_value = list_value->cast_ptr<tensor::Tensor>();
|
||||
MS_EXCEPTION_IF_NULL(list_tensor_value);
|
||||
return count_value->cast_ptr<tensor::Tensor>()->ValueEqual(*list_tensor_value);
|
||||
} else {
|
||||
return *count_value == *list_value;
|
||||
}
|
||||
|
|
|
@ -93,7 +93,9 @@ class ListExtend : public MetaFuncGraph {
|
|||
return os;
|
||||
}
|
||||
friend bool operator==(const ListExtend &lhs, const ListExtend &rhs) { return lhs.name_ == rhs.name_; }
|
||||
void AddNodeToElems(const AbstractBasePtr &arg, const FuncGraphPtr &ret, std::vector<AnfNodePtr> *elems);
|
||||
|
||||
private:
|
||||
static void AddNodeToElems(const AbstractBasePtr &arg, const FuncGraphPtr &ret, std::vector<AnfNodePtr> *elems);
|
||||
};
|
||||
using ListExtendPtr = std::shared_ptr<ListExtend>;
|
||||
|
||||
|
|
|
@ -264,6 +264,7 @@ bool CheckPythonIsInstance(const py::object &x, const AbstractBasePtr &cmp, cons
|
|||
}
|
||||
|
||||
bool CheckIsInstanceForFunc(const py::object &x_py_obj, const AbstractBasePtr &cmp, const py::module &mod) {
|
||||
MS_EXCEPTION_IF_NULL(cmp);
|
||||
if (cmp->isa<abstract::AbstractTuple>()) {
|
||||
const auto &cmp_tuple_elements = cmp->cast<abstract::AbstractTuplePtr>()->elements();
|
||||
return std::any_of(
|
||||
|
@ -283,6 +284,7 @@ bool CheckIsInstanceForFunc(const py::object &x_py_obj, const AbstractBasePtr &c
|
|||
}
|
||||
|
||||
bool CheckIsInstanceForSparse(const AbstractBasePtr &cmp, const std::string &target) {
|
||||
MS_EXCEPTION_IF_NULL(cmp);
|
||||
if (!cmp->isa<abstract::AbstractTuple>()) {
|
||||
return cmp->ToString() == target;
|
||||
}
|
||||
|
@ -292,6 +294,7 @@ bool CheckIsInstanceForSparse(const AbstractBasePtr &cmp, const std::string &tar
|
|||
}
|
||||
|
||||
py::object GetPrimitivePyObj(const abstract::PrimitiveAbstractClosurePtr &prim_abs) {
|
||||
MS_EXCEPTION_IF_NULL(prim_abs);
|
||||
auto prim = prim_abs->prim();
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_signature = prim->cast<prim::DoSignaturePrimitivePtr>();
|
||||
|
@ -303,12 +306,14 @@ py::object GetPrimitivePyObj(const abstract::PrimitiveAbstractClosurePtr &prim_a
|
|||
}
|
||||
|
||||
py::object GetMsClassPyObj(const abstract::PartialAbstractClosurePtr &ms_class_abs) {
|
||||
MS_EXCEPTION_IF_NULL(ms_class_abs);
|
||||
const auto &ms_class_args = ms_class_abs->args();
|
||||
if (ms_class_args.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "When the first input to IsInstance is PartialAbstractClosure, its args size should be 1 but "
|
||||
<< "got: " << ms_class_args.size() << ".";
|
||||
}
|
||||
auto first_arg = ms_class_args[0];
|
||||
MS_EXCEPTION_IF_NULL(first_arg);
|
||||
auto arg_type = first_arg->BuildType();
|
||||
auto arg_type_id = arg_type->type_id();
|
||||
if (arg_type_id != kObjectTypeClass) {
|
||||
|
@ -321,6 +326,7 @@ py::object GetMsClassPyObj(const abstract::PartialAbstractClosurePtr &ms_class_a
|
|||
}
|
||||
|
||||
bool CheckCmpValid(const AbstractBasePtr &cmp) {
|
||||
MS_EXCEPTION_IF_NULL(cmp);
|
||||
if (cmp->isa<abstract::AbstractSequence>()) {
|
||||
if (!cmp->isa<abstract::AbstractTuple>()) {
|
||||
return false;
|
||||
|
@ -352,7 +358,7 @@ bool CheckCmpValid(const AbstractBasePtr &cmp) {
|
|||
}
|
||||
return cmp_type_id == kMetaTypeTypeType;
|
||||
}
|
||||
return std::find(kSparsePrimStr.begin(), kSparsePrimStr.end(), cmp->ToString()) != kSparsePrimStr.end();
|
||||
return std::find(kSparsePrimStr.cbegin(), kSparsePrimStr.cend(), cmp->ToString()) != kSparsePrimStr.cend();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplIsInstance(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
@ -374,6 +380,7 @@ AbstractBasePtr InferImplIsInstance(const AnalysisEnginePtr &, const PrimitivePt
|
|||
}
|
||||
|
||||
// x is Cell object.
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
if (x->isa<abstract::FuncGraphAbstractClosure>()) {
|
||||
auto x_fg = x->cast<abstract::FuncGraphAbstractClosurePtr>()->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(x_fg);
|
||||
|
@ -611,12 +618,14 @@ AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
MS_LOG(INFO) << "The shape of dividend:" << shape_x->ToString() << ", the shape of divisor:" << div_shp->ToString();
|
||||
|
||||
auto div_shp_value = div_shp->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(div_shp_value);
|
||||
if (div_shp_value->isa<AnyValue>()) {
|
||||
MS_LOG(EXCEPTION) << "The 'tuple_div' operator shape's data field can't be anything, but got "
|
||||
<< args_spec_list[0]->ToString() << ".";
|
||||
}
|
||||
|
||||
auto shape_x_value = shape_x->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(shape_x_value);
|
||||
if (shape_x_value->isa<AnyValue>()) {
|
||||
MS_LOG(EXCEPTION) << "The 'tuple_div' operator shape's data field can't be anything, but got "
|
||||
<< args_spec_list[1]->ToString() << ".";
|
||||
|
@ -626,12 +635,16 @@ AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
MS_LOG(EXCEPTION) << "The size of inputs of 'tuple_div' operator must be the same, but the size of divisor tuple is"
|
||||
<< " " << div_shp->size() << ", the size of dividend tuple is " << shape_x->size() << ".";
|
||||
}
|
||||
|
||||
auto shape_x_data = shape_x_value->cast<ValueTuplePtr>()->value();
|
||||
auto div_shape_data = div_shp_value->cast<ValueTuplePtr>()->value();
|
||||
auto shape_x_tuple_value = shape_x_value->cast<ValueTuplePtr>();
|
||||
auto div_shape_tuple_value = div_shp_value->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape_x_tuple_value);
|
||||
MS_EXCEPTION_IF_NULL(div_shape_tuple_value);
|
||||
auto shape_x_data = shape_x_tuple_value->value();
|
||||
auto div_shape_data = div_shape_tuple_value->value();
|
||||
AbstractBasePtrList values;
|
||||
|
||||
for (size_t i = 0; i < div_shape_data.size(); i++) {
|
||||
MS_EXCEPTION_IF_NULL(div_shape_data[i]);
|
||||
if (div_shape_data[i]->cast<Int64ImmPtr>() == nullptr) {
|
||||
auto value_type = div_shape_data[i]->type();
|
||||
std::string str_type;
|
||||
|
@ -643,8 +656,8 @@ AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
MS_LOG(EXCEPTION) << "The data type of inputs of 'tuple_div' operator should be an int64 number, but got a "
|
||||
<< str_type << " number " << div_shape_data[i]->ToString() << ".";
|
||||
}
|
||||
int64_t shapex_value = GetValue<int64_t>(shape_x_data[i]);
|
||||
int64_t div_value = GetValue<int64_t>(div_shape_data[i]);
|
||||
auto shapex_value = GetValue<int64_t>(shape_x_data[i]);
|
||||
auto div_value = GetValue<int64_t>(div_shape_data[i]);
|
||||
MS_LOG(DEBUG) << "div_shp_shape data shapex_value :" << shapex_value << " div_value: " << div_value;
|
||||
if (div_value == 0) {
|
||||
MS_LOG(EXCEPTION) << "The divisor value should not be 0!";
|
||||
|
|
|
@ -794,6 +794,10 @@ AnfNodePtr KPynativeCellImpl::BuildKNodeForCNodeInput(const PynativeAdjointPtr &
|
|||
if (input_index < 1) {
|
||||
MS_EXCEPTION(ValueError) << "The input_index is smaller than 1.";
|
||||
}
|
||||
if (input_index > cnode_adjoint->op_args().size()) {
|
||||
MS_EXCEPTION(ValueError) << "The input_index: " << input_index
|
||||
<< " out of range:" << cnode_adjoint->op_args().size();
|
||||
}
|
||||
return NewValueNode(cnode_adjoint->op_args()[input_index - 1]);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -212,7 +212,7 @@ class ChoicePartialEliminater : public AnfVisitor {
|
|||
}
|
||||
|
||||
private:
|
||||
std::vector<AnfNodePtr> ArgsUnion(const std::vector<AnfNodePtrList> args_list) {
|
||||
static std::vector<AnfNodePtr> ArgsUnion(const std::vector<AnfNodePtrList> args_list) {
|
||||
std::set<AnfNodePtr> no_monad_args;
|
||||
std::set<AnfNodePtr> monad_args;
|
||||
for (const auto &args : args_list) {
|
||||
|
@ -230,8 +230,8 @@ class ChoicePartialEliminater : public AnfVisitor {
|
|||
return union_args;
|
||||
}
|
||||
|
||||
HashMap<FuncGraphPtr, HashMap<AnfNodePtr, size_t>> GenOldArgsIndexes(const AnfNodePtrList &fg_list,
|
||||
const std::vector<AnfNodePtrList> &args_list) {
|
||||
static HashMap<FuncGraphPtr, HashMap<AnfNodePtr, size_t>> GenOldArgsIndexes(
|
||||
const AnfNodePtrList &fg_list, const std::vector<AnfNodePtrList> &args_list) {
|
||||
HashMap<FuncGraphPtr, HashMap<AnfNodePtr, size_t>> old_args_indexes;
|
||||
for (size_t i = 0; i < fg_list.size(); ++i) {
|
||||
const auto func_graph = GetValueNode<FuncGraphPtr>(fg_list[i]);
|
||||
|
@ -239,16 +239,16 @@ class ChoicePartialEliminater : public AnfVisitor {
|
|||
const auto &args = args_list[i];
|
||||
HashMap<AnfNodePtr, size_t> args_indexes;
|
||||
size_t arg_index = 0;
|
||||
std::for_each(args.cbegin(), args.cend(), [&args_indexes, &arg_index](const AnfNodePtr &arg) {
|
||||
for (const auto &arg : args) {
|
||||
(void)args_indexes.emplace(arg, arg_index++);
|
||||
});
|
||||
}
|
||||
old_args_indexes[func_graph] = args_indexes;
|
||||
}
|
||||
return old_args_indexes;
|
||||
}
|
||||
|
||||
AnfNodePtr GetParameterByArg(const HashMap<FuncGraphPtr, HashMap<AnfNodePtr, size_t>> &all_old_args_index_map,
|
||||
const AnfNodePtr &arg) {
|
||||
static AnfNodePtr GetParameterByArg(const HashMap<FuncGraphPtr, HashMap<AnfNodePtr, size_t>> &all_old_args_index_map,
|
||||
const AnfNodePtr &arg) {
|
||||
MS_LOG(DEBUG) << "Get parameter by arg:" << arg->DebugString();
|
||||
for (const auto &[fg, old_args_index] : all_old_args_index_map) {
|
||||
auto it = old_args_index.find(arg);
|
||||
|
@ -352,7 +352,7 @@ class SwitchPartialEliminater : public ChoicePartialEliminater {
|
|||
// Create switch call.
|
||||
TraceGuard guard2(std::make_shared<TraceCopy>(switch_call->debug_info()));
|
||||
AnfNodePtrList switch_call_inputs{new_switch_cnode};
|
||||
switch_call_inputs.insert(switch_call_inputs.end(), new_args.begin(), new_args.end());
|
||||
(void)switch_call_inputs.insert(switch_call_inputs.end(), new_args.begin(), new_args.end());
|
||||
const auto new_call_node = switch_call->func_graph()->NewCNode(std::move(switch_call_inputs));
|
||||
new_call_node->set_abstract(switch_call->abstract());
|
||||
return new_call_node;
|
||||
|
|
|
@ -141,8 +141,8 @@ inline std::tuple<size_t, size_t, size_t> CalPosition(const OffsetIndex &offset_
|
|||
|
||||
inline InputXIndex CalInputXIndex(const OffsetIndex &offset_index, const DeformableOffsetGradDims &dims) {
|
||||
InputXIndex input_x_index;
|
||||
input_x_index.i = -1.0 * SizeToFloat(dims.pad_top);
|
||||
input_x_index.j = -1.0 * SizeToFloat(dims.pad_left);
|
||||
input_x_index.i = -1.0f * SizeToFloat(dims.pad_top);
|
||||
input_x_index.j = -1.0f * SizeToFloat(dims.pad_left);
|
||||
input_x_index.i += SizeToFloat(offset_index.offset_i * dims.stride_h + offset_index.kernel_i * dims.dilation_h);
|
||||
input_x_index.j += SizeToFloat(offset_index.offset_j * dims.stride_w + offset_index.kernel_j * dims.dilation_w);
|
||||
return input_x_index;
|
||||
|
|
|
@ -45,13 +45,13 @@ bool Compute(const ComputeParams<T, S> *params, const size_t start, const size_t
|
|||
std::vector<size_t> local_indices;
|
||||
for (size_t j = 0; j < params->indices_unit_rank_; ++j) {
|
||||
auto index = indices[i * params->indices_unit_rank_ + j];
|
||||
(void)local_indices.emplace_back(IntToSize(index));
|
||||
(void)local_indices.emplace_back(LongToSize(index));
|
||||
if (index < 0) {
|
||||
MS_LOG(ERROR) << "For '" << kKernelName
|
||||
<< "', each element in 'indices' must be greater than or equal to 0, but got " << index;
|
||||
return false;
|
||||
}
|
||||
offset += IntToSize(index) * out_strides->at(j) * params->unit_size_;
|
||||
offset += LongToSize(index) * out_strides->at(j) * params->unit_size_;
|
||||
}
|
||||
if (offset * sizeof(T) > params->x_mem_size_) {
|
||||
MS_LOG(ERROR) << "For '" << kKernelName
|
||||
|
|
|
@ -56,6 +56,7 @@ Primitive::Primitive(const Primitive &prim)
|
|||
prim_type_(prim.prim_type_),
|
||||
record_evaluate_add_attr_(false),
|
||||
is_const_prim_(false),
|
||||
const_input_indexes_(prim.const_input_indexes_),
|
||||
id_(prim.id_) {}
|
||||
|
||||
Primitive &Primitive::operator=(const Primitive &other) {
|
||||
|
@ -72,6 +73,7 @@ Primitive &Primitive::operator=(const Primitive &other) {
|
|||
record_evaluate_add_attr_ = false;
|
||||
is_const_prim_ = false;
|
||||
id_ = other.id_;
|
||||
const_input_indexes_ = other.const_input_indexes_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue