Handle the primitives who transparent pass or partial use tuple/list.
This commit is contained in:
Zhang Qinghua 2022-01-20 12:35:22 +08:00 committed by 张清华
parent c43e48b563
commit 0deb008237
22 changed files with 81 additions and 168 deletions

View File

@ -51,9 +51,6 @@ AbstractBasePtr InferImplTupleOrListEqual(const std::string &op_name, const Abst
CheckArgsSize(op_name, args_spec_list, 2);
auto input_x = CheckArg<T>(op_name, args_spec_list, 0);
auto input_y = CheckArg<T>(op_name, args_spec_list, 1);
SetSequenceElementsUseFlags(input_x, true);
SetSequenceElementsUseFlags(input_y, true);
ValuePtr x_value = input_x->BuildValue();
ValuePtr y_value = input_y->BuildValue();
return std::make_shared<AbstractScalar>(*x_value == *y_value);
@ -298,8 +295,6 @@ AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const
return std::make_shared<AbstractTuple>(elem_list);
}
SetSequenceElementsUseFlags(arg_x, true);
SetSequenceElementsUseFlags(arg_y, true);
return BroadcastGradientArgsDiff(x_shape, y_shape);
}
@ -354,7 +349,6 @@ AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const Primi
MS_EXCEPTION_IF_NULL(result2);
MS_EXCEPTION_IF_NULL(result1->abstract());
MS_EXCEPTION_IF_NULL(result2->abstract());
SetSequenceElementsUseFlags(lst, true);
return result1->abstract()->Join(result2->abstract());
}
@ -370,7 +364,6 @@ AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const Primitiv
AbstractBasePtrList elem_list;
(void)std::transform(tuple_elements.rbegin(), tuple_elements.rend(), std::back_inserter(elem_list),
[](const AbstractBasePtr &elem) { return elem->Clone(); });
SetSequenceElementsUseFlags(input, true);
return std::make_shared<AbstractTuple>(elem_list);
}
@ -412,7 +405,6 @@ AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitiveP
}
auto axis_value_ptr = axis_value->cast<ValueSequencePtr>();
MS_EXCEPTION_IF_NULL(axis_value_ptr);
SetSequenceElementsUseFlags(shape_x, true);
return DoInferReduceShape(shape_x, x_shp_value, axis_value_ptr, primitive);
}
@ -477,8 +469,6 @@ AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr
auto result_v = MakeValue(result);
values.push_back(std::make_shared<AbstractScalar>(result_v, result_v->type()));
}
SetSequenceElementsUseFlags(shape_x, true);
SetSequenceElementsUseFlags(div_shp, true);
return std::make_shared<AbstractTuple>(values);
}

View File

@ -52,70 +52,20 @@ namespace abstract {
using mindspore::parse::PyObjectWrapper;
mindspore::HashSet<std::string> prims_to_skip_undetermined_infer{
"MakeTuple", "make_list", "Switch", "env_setitem", "env_getitem", "Load", "UpdateState"};
prim::kPrimMakeTuple->name(), prim::kPrimMakeList->name(), prim::kPrimSwitch->name(),
prim::kPrimEnvironSet->name(), prim::kPrimEnvironGet->name(), prim::kPrimLoad->name(),
prim::kPrimUpdateState->name()};
// The Python primitives who use tuple/list elements.
// We consider all tuple/list arguments are used by now.
// Should check 'tuple argument index' and 'element use index' later.
mindspore::HashSet<std::string> prims_use_sequence_elements{
prim::kPrimStack->name(),
prim::kPrimConcat->name(),
prim::kPrimTupleToArray->name(),
prim::kPrimPack->name(),
prim::kPrimSlice->name(),
prim::kPrimStridedSlice->name(),
prim::kPrimScatterNd->name(),
prim::kPrimReshape->name(),
prim::kPrimTile->name(),
prim::kPrimConv3DBackpropFilter->name(),
prim::kPrimCentralization->name(),
prim::kPrimMerge->name(),
prim::kPrimCustom->name(),
prim::kPrimAssert->name(),
prim::kPrimReduceMean->name(),
prim::kPrimReduceSum->name(),
prim::kPrimReduceAll->name(),
prim::kPrimReduceAny->name(),
prim::kPrimReduceMax->name(),
prim::kPrimReduceMin->name(),
prim::kPrimLstm->name(),
prim::kPrimConv3DBackpropInput->name(),
prim::kPrimCheckBprop->name(),
prim::kPrimPush->name(),
prim::kPrimIdentity->name(),
prim::kPrimPyFunc->name(),
prim::kPrimStandardNormal->name(),
prim::kPrimUniformReal->name(),
prim::kPrimSparseToDense->name(),
prim::kPrimSparseTensorDenseMatmul->name(),
prim::kPrimBroadcast->name(),
prim::kPrimEinsumGrad->name(),
prim::kPrimFlattenGrad->name(),
prim::kPrimMirror->name(),
prim::kPrimMirrorMiniStep->name(),
prim::kPrimMiniStepAllGather->name(),
prim::kPrimMicroStepAllGather->name(),
prim::kPrimVirtualDiv->name(),
prim::kPrimVirtualAdd->name(),
prim::kPrimVirtualDataset->name(),
prim::kPrimVirtualOutput->name(),
prim::kPrimFill->name(),
"InvertPermutation",
"Meshgrid",
"TransShape",
"ParallelConcat",
"CudnnGRU",
"GetModel",
"UpdateModel",
"ReduceProd",
"StandardLaplace",
"Gamma",
"Poisson",
"UniformInt",
"BufferSample",
"BufferAppend",
"BufferGetItem",
};
// The Python primitives who visit tuple/list elements, but not consume all elements.
// Including:
// - Consume no element. For instance, MakeTuple.
// - Consume partial elements, not all. For instance, TupleGetItem.
// Map{"primitive name", {vector<int>:"index to transparent pass, -1 means all elements"}}
mindspore::HashMap<std::string, std::vector<int>> prims_transparent_pass_sequence{
{prim::kPrimReturn->name(), std::vector({0})}, {prim::kPrimDepend->name(), std::vector({0})},
{prim::kPrimIdentity->name(), std::vector({0})}, {prim::kPrimMakeTuple->name(), std::vector({-1})},
{prim::kPrimMakeList->name(), std::vector({-1})}, {prim::kPrimListAppend->name(), std::vector({0})},
{prim::kPrimTupleGetItem->name(), std::vector({0})}, {prim::kPrimListGetItem->name(), std::vector({0})}};
EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) {
@ -689,13 +639,14 @@ void CheckCustomPrimOutputInferResult(const PrimitivePtr &prim, const AbstractBa
AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) {
// Convert to AbstractValue based on type and shape
auto out_dtype = output[ATTR_DTYPE];
if (output[ATTR_VALUE].is_none()) {
return MakePyInferRes2Abstract(output);
}
// Convert pyobject to Value, then to AbstractValue
ValuePtr converted_ret = nullptr;
auto out_dtype = output[ATTR_DTYPE];
TypePtr dtype = py::isinstance<Type>(out_dtype) ? out_dtype.cast<TypePtr>() : nullptr;
ValuePtr converted_ret = nullptr;
bool converted = parse::ConvertData(output[ATTR_VALUE], &converted_ret, false, dtype);
if (!converted) {
MS_LOG(EXCEPTION) << "Convert data failed";
@ -804,7 +755,61 @@ EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &en
return eval_result;
}
namespace {
void CheckSequenceArgumentForCppPrimitive(const PrimitivePtr &prim, const AbstractBasePtrList &args) {
// To check tuple/list operations with a white list of Python primitive.
auto iter = prims_transparent_pass_sequence.find(prim->name());
if (iter == prims_transparent_pass_sequence.end()) {
// The primitive use all elements of each argument.
for (size_t i = 0; i < args.size(); ++i) {
if (args[i]->isa<abstract::AbstractSequence>()) {
MS_LOG(DEBUG) << "Primitive \'" << prim->name() << "\' is consuming tuple/list arguments[" << i
<< "]: " << args[i]->ToString();
SetSequenceElementsUseFlags(args[i], true);
}
}
return;
}
// It's transparent pass primitive or using partial elements primitive.
auto index_list = iter->second;
if (index_list.empty()) {
MS_LOG(EXCEPTION) << "The primitive list should not be empty for " << prim->name();
}
// Ignore all arguments, no need checking if AbstractSequence.
if (index_list[0] == -1) {
return;
}
// Check the specific arguments index.
for (size_t i = 0; i < args.size(); ++i) {
if (!args[i]->isa<abstract::AbstractSequence>()) {
continue;
}
if (std::find(index_list.begin(), index_list.end(), i) == index_list.end()) {
// For current tuple/list argument, it's not a primitive of total transparent pass or partial element use.
MS_LOG(DEBUG) << "Primitive \'" << prim->name() << "\' is consuming specific tuple/list arguments[" << i
<< "]: " << args[i]->ToString();
SetSequenceElementsUseFlags(args[i], true);
}
}
}
void CheckSequenceArgumentForPythonPrimitive(const PrimitivePtr &prim, const AbstractBasePtrList &args) {
// Consider all primitive implemented python infer() real use the tuple/list arguments.
for (size_t i = 0; i < args.size(); ++i) {
if (args[i]->isa<abstract::AbstractSequence>()) {
MS_LOG(DEBUG) << "Primitive \'" << prim->name() << "\' is consuming tuple/list arguments[" << i
<< "]: " << args[i]->ToString();
SetSequenceElementsUseFlags(args[i], true);
}
}
}
} // namespace
EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
// To check tuple/list operations with a white list of Python primitive.
CheckSequenceArgumentForCppPrimitive(prim_, args);
if (prims_to_skip_undetermined_infer.find(prim_->name()) == prims_to_skip_undetermined_infer.end()) {
auto ret_abstract = AbstractEval(args);
if (ret_abstract != nullptr) {
@ -846,6 +851,9 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c
}
EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
// Consider all primitive implemented python infer() real use the tuple/list arguments.
CheckSequenceArgumentForPythonPrimitive(prim_py_, args);
// Ensure input arguments are evaluated.
auto ret_abstract = AbstractEval(args);
if (ret_abstract != nullptr) {
@ -875,9 +883,9 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, con
prim_py_->EndRecordAddAttr();
const auto &added_attrs = prim_py_->evaluate_added_attrs();
MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output);
auto res_spec = PyInferRes2Abstract(prim_py_, output);
MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << ".";
eval_result = std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
auto res_abs = PyInferRes2Abstract(prim_py_, output);
MS_LOG(DEBUG) << "Python InferTensor result abstract: " << res_abs->ToString();
eval_result = std::make_shared<EvalResult>(res_abs, std::make_shared<AttrValueMap>(added_attrs));
// Save result to global primitive eval cache.
if (enable_global_cache) {
eval_cache_->Put(prim_py_, std::move(input_attrs), args, eval_result);
@ -886,13 +894,6 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, con
// Save result to evaluator cache.
evaluator_cache_mgr_->SetValue(args, eval_result);
}
// To check tuple/list operations with a white list of Python primitive.
if (prims_use_sequence_elements.find(prim_py_->name()) != prims_use_sequence_elements.end()) {
// Set all used flags of tuple as true.
for (auto &arg : args) {
SetSequenceElementsUseFlags(arg, true);
}
}
return eval_result;
}

View File

@ -303,7 +303,6 @@ AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const Abst
// Inputs: a tuple or list or dict.
CheckArgsSize(op_name, args_spec_list, 1);
auto arg = CheckArg<T>(op_name, args_spec_list, 0);
SetSequenceElementsUseFlags(arg, true);
return std::make_shared<AbstractScalar>(SizeToLong(arg->size()));
}
} // namespace abstract

View File

@ -115,8 +115,6 @@ AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const Primiti
(void)std::transform(res.begin(), res.end(), std::back_inserter(elems), [](int64_t n) -> AbstractBasePtr {
return std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(n), kInt64);
});
SetSequenceElementsUseFlags(xs, true);
SetSequenceElementsUseFlags(ys, true);
return std::make_shared<AbstractTuple>(elems);
}
@ -137,8 +135,6 @@ AbstractBasePtr InferImplStack(const AnalysisEnginePtr &, const PrimitivePtr &pr
arg = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
tuple_len = arg->elements().size();
tensor_base = CheckArg<AbstractTensor>(op_name, arg->elements(), 0);
// For Stack(tuple), set all used flags of tuple as true.
SetSequenceElementsUseFlags(args_spec_list[0], true);
} else if (args_spec_list[0]->isa<AbstractTensor>()) {
tuple_len = args_spec_list.size();
tensor_base = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
@ -1369,8 +1365,6 @@ AbstractBasePtr InferImplDynamicStitch(const AnalysisEnginePtr &, const Primitiv
min_shape[0] = 1;
max_shape[0] = indices_total_size * EXPAND_MAX;
}
SetSequenceElementsUseFlags(input_tuple, true);
SetSequenceElementsUseFlags(input_tuple_1, true);
return std::make_shared<AbstractTensor>(infer_type,
std::make_shared<abstract::Shape>(out_shape, min_shape, max_shape));
}
@ -1381,9 +1375,6 @@ AbstractBasePtr InferImplTensorCopySlices(const AnalysisEnginePtr &, const Primi
constexpr auto kTensorCopySlicesInputNum = 5;
CheckArgsSize(op_name, args_spec_list, kTensorCopySlicesInputNum);
AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
for (size_t i = 2; i < args_spec_list.size(); ++i) {
SetSequenceElementsUseFlags(args_spec_list[i], true);
}
return std::make_shared<AbstractTensor>(input->element(), input->shape());
}
} // namespace abstract

View File

@ -191,9 +191,6 @@ AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &p
return args_spec_list[0];
}
// For F.depend(x, MakeTuple()) or F.depend(x, tuple), set all used flags of tuple as true.
SetSequenceElementsUseFlags(dependant_abstract, true);
auto depends = args_spec_list[0]->Broaden(); // Avoid eliminating the dependent node.
if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR)) {
// For scalar, need to set value to kAnyValue, because broaden scalar will not change the value.
@ -210,12 +207,6 @@ AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitiveP
MS_LOG(EXCEPTION) << primitive->name() << " input args size should be at least 1, but got 0";
}
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
// For UpdateState(x, MakeTuple()) or UpdateState(x, tuple), set all used flags of tuple as true.
for (size_t i = 1; i < args_spec_list.size(); i++) {
SetSequenceElementsUseFlags(args_spec_list[i], true);
}
return args_spec_list[0]->Broaden();
}
@ -279,7 +270,6 @@ AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const Primitiv
ret->set_indices(indices);
ret->set_values(values);
ret->set_dense_shape(dense_shape);
SetSequenceElementsUseFlags(dense_shape, true);
return ret;
}
@ -383,7 +373,6 @@ AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const Primi
ret->set_indices(indices);
ret->set_values(values);
ret->set_dense_shape(dense_shape);
SetSequenceElementsUseFlags(dense_shape, true);
return ret;
}
@ -597,7 +586,6 @@ AbstractBasePtr InferImplMakeCSRTensor(const AnalysisEnginePtr &, const Primitiv
ret->set_indices(indices);
ret->set_values(values);
ret->set_dense_shape(shape);
SetSequenceElementsUseFlags(shape, true);
return ret;
}

View File

@ -99,9 +99,6 @@ AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitiveP
}
}
// For F.SwitchLayer(MakeTuple(), x) or F.depend(tuple, x), set all used flags of tuple as true.
SetSequenceElementsUseFlags(branches_abs, true);
auto b = branches[0];
// Return AbstractFuncUnion, otherwise the switch_layer will be replaced by branches[0]
// which will cancel the out of bound checking for index

View File

@ -60,8 +60,6 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr
auto key_string = GetValue<std::string>(keyPtr);
(void)key_value.emplace_back(key_string, value_list[index]);
}
SetSequenceElementsUseFlags(keys, true);
SetSequenceElementsUseFlags(values, true);
return std::make_shared<AbstractDictionary>(key_value);
}
@ -167,7 +165,6 @@ AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const Abstra
constexpr int target_value_index = 2;
elements[index_unsigned_value] = args_spec_list[target_value_index];
MS_LOG(DEBUG) << "SetItem use flags, index: " << index_unsigned_value << ", for " << queue->ToString();
SetSequenceElementsUseFlags(queue, index_unsigned_value, true);
return std::make_shared<T>(elements, queue->sequence_nodes());
}

View File

@ -96,7 +96,6 @@ AbstractBasePtr AccumulateNV2Infer(const abstract::AnalysisEnginePtr &, const Pr
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
SetSequenceElementsUseFlags(input_args[0], true);
auto infer_type = AccumulateNV2InferType(primitive, input_args);
auto infer_shape = AccumulateNV2InferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);

View File

@ -98,10 +98,7 @@ AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto res = abstract::MakeAbstract(AddNInferShape(primitive, input_args), AddNInferType(primitive, input_args));
// For AddN(MakeTuple()) or AddN(tuple), set all used flags of tuple as true.
SetSequenceElementsUseFlags(input_args[0], true);
return res;
return abstract::MakeAbstract(AddNInferShape(primitive, input_args), AddNInferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(AddN, prim::kPrimAddN, AddNInfer, nullptr, true);
} // namespace ops

View File

@ -39,7 +39,6 @@ abstract::TupleShapePtr CTCLossV2InferShape(const PrimitivePtr &primitive,
prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
SetSequenceElementsUseFlags(item, true);
}
auto log_probs_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
auto targets_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());

View File

@ -37,7 +37,6 @@ abstract::ShapePtr CTCLossV2GradInferShape(const PrimitivePtr &primitive,
prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
SetSequenceElementsUseFlags(item, true);
}
auto log_probs_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
auto log_probs_shape = log_probs_shape_map[kShape];

View File

@ -118,14 +118,11 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
} // namespace
AbstractBasePtr DynamicResizeNearestNeighborInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
constexpr size_t size_index = 1;
auto prim_name = primitive->name();
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("infer", SizeToLong(CheckAndConvertUtils::GetRemoveMonadAbsNum(input_args)),
kEqual, input_num, prim_name);
auto res = abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
// Set all used flags of tuple as true.
SetSequenceElementsUseFlags(input_args[size_index], true);
return res;
}
REGISTER_PRIMITIVE_EVAL_IMPL(DynamicResizeNearestNeighbor, prim::kPrimDynamicResizeNearestNeighbor,

View File

@ -240,8 +240,6 @@ AbstractBasePtr EinsumInfer(const abstract::AnalysisEnginePtr &, const Primitive
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
auto res =
std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), InferShape(primitive, input_args));
// For Einsum(tuple/list), set all used flags of tuple as true.
SetSequenceElementsUseFlags(input_args[0], true);
return res;
}
REGISTER_PRIMITIVE_EVAL_IMPL(Einsum, prim::kPrimEinsum, EinsumInfer, nullptr, true);

View File

@ -60,11 +60,8 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP
AbstractBasePtr AvgPool3DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
constexpr size_t origin_input_size_index = 0;
auto res = std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape());
// Set all used flags of tuple as true.
SetSequenceElementsUseFlags(input_args[origin_input_size_index], true);
return res;
}

View File

@ -241,8 +241,6 @@ AbstractBasePtr Conv2DBackpropFilterInfer(const abstract::AnalysisEnginePtr &, c
}
auto res = std::make_shared<abstract::AbstractTensor>(Conv2DBackpropFilterInferType(primitive, input_args),
Conv2DBackpropFilterInferShape(primitive, input_args));
// Set all used flags of tuple as true.
SetSequenceElementsUseFlags(input_args[kFilterSizeIndex], true);
return res;
}
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DBackpropFilter, prim::kPrimConv2DBackpropFilter, Conv2DBackpropFilterInfer, nullptr,

View File

@ -40,7 +40,6 @@ float DropoutGrad::get_keep_prob() const {
AbstractBasePtr DropoutGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
constexpr size_t shape_index = 0;
auto op_name = primitive->name();
const int64_t input_num = 2;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, op_name);
@ -55,8 +54,6 @@ AbstractBasePtr DropoutGradInfer(const abstract::AnalysisEnginePtr &, const Prim
auto out_type = CheckAndConvertUtils::CheckTensorTypeValid("x", dy_type, valid_types, op_name);
auto shape = CheckAndConvertUtils::GetTensorInputShape(op_name, input_args, dy_index);
auto res = abstract::MakeAbstract(shape, out_type);
// Set all used flags of tuple as true.
SetSequenceElementsUseFlags(input_args[shape_index], true);
return res;
}
REGISTER_PRIMITIVE_EVAL_IMPL(DropoutGrad, prim::kPrimDropoutGrad, DropoutGradInfer, nullptr, true);

View File

@ -118,19 +118,10 @@ TypePtr StridedSliceGradInferType(const PrimitivePtr &primitive, const std::vect
AbstractBasePtr StridedSliceGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
constexpr size_t shape_index = 1;
constexpr size_t begin_index = 2;
constexpr size_t end_index = 3;
constexpr size_t stride_index = 4;
constexpr int64_t input_num = 5;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
auto res = abstract::MakeAbstract(StridedSliceGradInferShape(primitive, input_args),
StridedSliceGradInferType(primitive, input_args));
// Set all used flags of tuple as true.
SetSequenceElementsUseFlags(input_args[shape_index], true);
SetSequenceElementsUseFlags(input_args[begin_index], true);
SetSequenceElementsUseFlags(input_args[end_index], true);
SetSequenceElementsUseFlags(input_args[stride_index], true);
return res;
}

View File

@ -105,7 +105,6 @@ AbstractBasePtr LpNormInfer(const abstract::AnalysisEnginePtr &, const Primitive
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
auto infer_type = InferType(primitive, input_args);
auto infer_shape = InferShape(primitive, input_args);
SetSequenceElementsUseFlags(input_args[0], true);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(LpNorm, prim::kPrimLpNorm, LpNormInfer, nullptr, true);

View File

@ -182,16 +182,10 @@ TypePtr InferType(const PrimitivePtr &primitive) {
} // namespace
AbstractBasePtr NeighborExchangeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
constexpr size_t input_tuple_index = 0;
Check(primitive, input_args);
auto type = InferType(primitive);
auto shape = InferShape(primitive);
auto res = abstract::MakeAbstract(shape, type);
// Set all used flags of tuple as true.
if (!input_args.empty()) {
SetSequenceElementsUseFlags(input_args[input_tuple_index], true);
}
return res;
return abstract::MakeAbstract(shape, type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(NeighborExchange, prim::kPrimNeighborExchange, NeighborExchangeInfer, nullptr, true);
} // namespace ops

View File

@ -57,14 +57,10 @@ TypePtr OnesInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePt
AbstractBasePtr OnesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
constexpr size_t shape_index = 0;
const std::string op_name = primitive->name();
const int64_t input_num = 2;
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, op_name);
auto res = abstract::MakeAbstract(OnesInferShape(primitive, input_args), OnesInferType(primitive, input_args));
// Set all used flags of tuple as true.
SetSequenceElementsUseFlags(input_args[shape_index], true);
return res;
return abstract::MakeAbstract(OnesInferShape(primitive, input_args), OnesInferType(primitive, input_args));
}
ValuePtr OnesInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {

View File

@ -93,20 +93,13 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
AbstractBasePtr TransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
constexpr size_t input_perm_index = 1;
// The second input is optional.
constexpr size_t input_size1 = 1;
constexpr size_t input_size2 = 2;
(void)CheckAndConvertUtils::CheckInteger("Transpose infer", SizeToLong(input_args.size()), kGreaterEqual, input_size1,
primitive->name());
auto type = InferType(primitive, input_args);
auto shape = InferShape(primitive, input_args);
auto res = abstract::MakeAbstract(shape, type);
if (input_args.size() == input_size2) {
SetSequenceElementsUseFlags(input_args[input_perm_index], true);
}
return res;
return abstract::MakeAbstract(shape, type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(Transpose, prim::kPrimTranspose, TransposeInfer, nullptr, true);
} // namespace ops

View File

@ -59,11 +59,7 @@ TypePtr ZerosInferType(const PrimitivePtr &prim, const std::vector<AbstractBaseP
AbstractBasePtr ZerosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
constexpr size_t shape_index = 0;
auto res = abstract::MakeAbstract(ZerosInferShape(primitive, input_args), ZerosInferType(primitive, input_args));
// Set all used flags of tuple as true.
SetSequenceElementsUseFlags(input_args[shape_index], true);
return res;
return abstract::MakeAbstract(ZerosInferShape(primitive, input_args), ZerosInferType(primitive, input_args));
}
ValuePtr ZerosInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {