fixed 4e7c012 from https://gitee.com/zh_qh/mindspore/pulls/28963
Handle the primitives who transparent pass or partial use tuple/list.
This commit is contained in:
parent
c43e48b563
commit
0deb008237
|
@ -51,9 +51,6 @@ AbstractBasePtr InferImplTupleOrListEqual(const std::string &op_name, const Abst
|
|||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
auto input_x = CheckArg<T>(op_name, args_spec_list, 0);
|
||||
auto input_y = CheckArg<T>(op_name, args_spec_list, 1);
|
||||
SetSequenceElementsUseFlags(input_x, true);
|
||||
SetSequenceElementsUseFlags(input_y, true);
|
||||
|
||||
ValuePtr x_value = input_x->BuildValue();
|
||||
ValuePtr y_value = input_y->BuildValue();
|
||||
return std::make_shared<AbstractScalar>(*x_value == *y_value);
|
||||
|
@ -298,8 +295,6 @@ AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const
|
|||
|
||||
return std::make_shared<AbstractTuple>(elem_list);
|
||||
}
|
||||
SetSequenceElementsUseFlags(arg_x, true);
|
||||
SetSequenceElementsUseFlags(arg_y, true);
|
||||
return BroadcastGradientArgsDiff(x_shape, y_shape);
|
||||
}
|
||||
|
||||
|
@ -354,7 +349,6 @@ AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const Primi
|
|||
MS_EXCEPTION_IF_NULL(result2);
|
||||
MS_EXCEPTION_IF_NULL(result1->abstract());
|
||||
MS_EXCEPTION_IF_NULL(result2->abstract());
|
||||
SetSequenceElementsUseFlags(lst, true);
|
||||
return result1->abstract()->Join(result2->abstract());
|
||||
}
|
||||
|
||||
|
@ -370,7 +364,6 @@ AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const Primitiv
|
|||
AbstractBasePtrList elem_list;
|
||||
(void)std::transform(tuple_elements.rbegin(), tuple_elements.rend(), std::back_inserter(elem_list),
|
||||
[](const AbstractBasePtr &elem) { return elem->Clone(); });
|
||||
SetSequenceElementsUseFlags(input, true);
|
||||
return std::make_shared<AbstractTuple>(elem_list);
|
||||
}
|
||||
|
||||
|
@ -412,7 +405,6 @@ AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitiveP
|
|||
}
|
||||
auto axis_value_ptr = axis_value->cast<ValueSequencePtr>();
|
||||
MS_EXCEPTION_IF_NULL(axis_value_ptr);
|
||||
SetSequenceElementsUseFlags(shape_x, true);
|
||||
return DoInferReduceShape(shape_x, x_shp_value, axis_value_ptr, primitive);
|
||||
}
|
||||
|
||||
|
@ -477,8 +469,6 @@ AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
auto result_v = MakeValue(result);
|
||||
values.push_back(std::make_shared<AbstractScalar>(result_v, result_v->type()));
|
||||
}
|
||||
SetSequenceElementsUseFlags(shape_x, true);
|
||||
SetSequenceElementsUseFlags(div_shp, true);
|
||||
return std::make_shared<AbstractTuple>(values);
|
||||
}
|
||||
|
||||
|
|
|
@ -52,70 +52,20 @@ namespace abstract {
|
|||
using mindspore::parse::PyObjectWrapper;
|
||||
|
||||
mindspore::HashSet<std::string> prims_to_skip_undetermined_infer{
|
||||
"MakeTuple", "make_list", "Switch", "env_setitem", "env_getitem", "Load", "UpdateState"};
|
||||
prim::kPrimMakeTuple->name(), prim::kPrimMakeList->name(), prim::kPrimSwitch->name(),
|
||||
prim::kPrimEnvironSet->name(), prim::kPrimEnvironGet->name(), prim::kPrimLoad->name(),
|
||||
prim::kPrimUpdateState->name()};
|
||||
|
||||
// The Python primitives who use tuple/list elements.
|
||||
// We consider all tuple/list arguments are used by now.
|
||||
// Should check 'tuple argument index' and 'element use index' later.
|
||||
mindspore::HashSet<std::string> prims_use_sequence_elements{
|
||||
prim::kPrimStack->name(),
|
||||
prim::kPrimConcat->name(),
|
||||
prim::kPrimTupleToArray->name(),
|
||||
prim::kPrimPack->name(),
|
||||
prim::kPrimSlice->name(),
|
||||
prim::kPrimStridedSlice->name(),
|
||||
prim::kPrimScatterNd->name(),
|
||||
prim::kPrimReshape->name(),
|
||||
prim::kPrimTile->name(),
|
||||
prim::kPrimConv3DBackpropFilter->name(),
|
||||
prim::kPrimCentralization->name(),
|
||||
prim::kPrimMerge->name(),
|
||||
prim::kPrimCustom->name(),
|
||||
prim::kPrimAssert->name(),
|
||||
prim::kPrimReduceMean->name(),
|
||||
prim::kPrimReduceSum->name(),
|
||||
prim::kPrimReduceAll->name(),
|
||||
prim::kPrimReduceAny->name(),
|
||||
prim::kPrimReduceMax->name(),
|
||||
prim::kPrimReduceMin->name(),
|
||||
prim::kPrimLstm->name(),
|
||||
prim::kPrimConv3DBackpropInput->name(),
|
||||
prim::kPrimCheckBprop->name(),
|
||||
prim::kPrimPush->name(),
|
||||
prim::kPrimIdentity->name(),
|
||||
prim::kPrimPyFunc->name(),
|
||||
prim::kPrimStandardNormal->name(),
|
||||
prim::kPrimUniformReal->name(),
|
||||
prim::kPrimSparseToDense->name(),
|
||||
prim::kPrimSparseTensorDenseMatmul->name(),
|
||||
prim::kPrimBroadcast->name(),
|
||||
prim::kPrimEinsumGrad->name(),
|
||||
prim::kPrimFlattenGrad->name(),
|
||||
prim::kPrimMirror->name(),
|
||||
prim::kPrimMirrorMiniStep->name(),
|
||||
prim::kPrimMiniStepAllGather->name(),
|
||||
prim::kPrimMicroStepAllGather->name(),
|
||||
prim::kPrimVirtualDiv->name(),
|
||||
prim::kPrimVirtualAdd->name(),
|
||||
prim::kPrimVirtualDataset->name(),
|
||||
prim::kPrimVirtualOutput->name(),
|
||||
prim::kPrimFill->name(),
|
||||
"InvertPermutation",
|
||||
"Meshgrid",
|
||||
"TransShape",
|
||||
"ParallelConcat",
|
||||
"CudnnGRU",
|
||||
"GetModel",
|
||||
"UpdateModel",
|
||||
"ReduceProd",
|
||||
"StandardLaplace",
|
||||
"Gamma",
|
||||
"Poisson",
|
||||
"UniformInt",
|
||||
"BufferSample",
|
||||
"BufferAppend",
|
||||
"BufferGetItem",
|
||||
};
|
||||
// The Python primitives who visit tuple/list elements, but not consume all elements.
|
||||
// Including:
|
||||
// - Consume no element. For instance, MakeTuple.
|
||||
// - Consume partial elements, not all. For instance, TupleGetItem.
|
||||
// Map{"primitive name", {vector<int>:"index to transparent pass, -1 means all elements"}}
|
||||
mindspore::HashMap<std::string, std::vector<int>> prims_transparent_pass_sequence{
|
||||
{prim::kPrimReturn->name(), std::vector({0})}, {prim::kPrimDepend->name(), std::vector({0})},
|
||||
{prim::kPrimIdentity->name(), std::vector({0})}, {prim::kPrimMakeTuple->name(), std::vector({-1})},
|
||||
{prim::kPrimMakeList->name(), std::vector({-1})}, {prim::kPrimListAppend->name(), std::vector({0})},
|
||||
{prim::kPrimTupleGetItem->name(), std::vector({0})}, {prim::kPrimListGetItem->name(), std::vector({0})}};
|
||||
|
||||
EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
|
@ -689,13 +639,14 @@ void CheckCustomPrimOutputInferResult(const PrimitivePtr &prim, const AbstractBa
|
|||
|
||||
AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) {
|
||||
// Convert to AbstractValue based on type and shape
|
||||
auto out_dtype = output[ATTR_DTYPE];
|
||||
if (output[ATTR_VALUE].is_none()) {
|
||||
return MakePyInferRes2Abstract(output);
|
||||
}
|
||||
|
||||
// Convert pyobject to Value, then to AbstractValue
|
||||
ValuePtr converted_ret = nullptr;
|
||||
auto out_dtype = output[ATTR_DTYPE];
|
||||
TypePtr dtype = py::isinstance<Type>(out_dtype) ? out_dtype.cast<TypePtr>() : nullptr;
|
||||
ValuePtr converted_ret = nullptr;
|
||||
bool converted = parse::ConvertData(output[ATTR_VALUE], &converted_ret, false, dtype);
|
||||
if (!converted) {
|
||||
MS_LOG(EXCEPTION) << "Convert data failed";
|
||||
|
@ -804,7 +755,61 @@ EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &en
|
|||
return eval_result;
|
||||
}
|
||||
|
||||
namespace {
|
||||
void CheckSequenceArgumentForCppPrimitive(const PrimitivePtr &prim, const AbstractBasePtrList &args) {
|
||||
// To check tuple/list operations with a white list of Python primitive.
|
||||
auto iter = prims_transparent_pass_sequence.find(prim->name());
|
||||
if (iter == prims_transparent_pass_sequence.end()) {
|
||||
// The primitive use all elements of each argument.
|
||||
for (size_t i = 0; i < args.size(); ++i) {
|
||||
if (args[i]->isa<abstract::AbstractSequence>()) {
|
||||
MS_LOG(DEBUG) << "Primitive \'" << prim->name() << "\' is consuming tuple/list arguments[" << i
|
||||
<< "]: " << args[i]->ToString();
|
||||
SetSequenceElementsUseFlags(args[i], true);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// It's transparent pass primitive or using partial elements primitive.
|
||||
auto index_list = iter->second;
|
||||
if (index_list.empty()) {
|
||||
MS_LOG(EXCEPTION) << "The primitive list should not be empty for " << prim->name();
|
||||
}
|
||||
// Ignore all arguments, no need checking if AbstractSequence.
|
||||
if (index_list[0] == -1) {
|
||||
return;
|
||||
}
|
||||
// Check the specific arguments index.
|
||||
for (size_t i = 0; i < args.size(); ++i) {
|
||||
if (!args[i]->isa<abstract::AbstractSequence>()) {
|
||||
continue;
|
||||
}
|
||||
if (std::find(index_list.begin(), index_list.end(), i) == index_list.end()) {
|
||||
// For current tuple/list argument, it's not a primitive of total transparent pass or partial element use.
|
||||
MS_LOG(DEBUG) << "Primitive \'" << prim->name() << "\' is consuming specific tuple/list arguments[" << i
|
||||
<< "]: " << args[i]->ToString();
|
||||
SetSequenceElementsUseFlags(args[i], true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CheckSequenceArgumentForPythonPrimitive(const PrimitivePtr &prim, const AbstractBasePtrList &args) {
|
||||
// Consider all primitive implemented python infer() real use the tuple/list arguments.
|
||||
for (size_t i = 0; i < args.size(); ++i) {
|
||||
if (args[i]->isa<abstract::AbstractSequence>()) {
|
||||
MS_LOG(DEBUG) << "Primitive \'" << prim->name() << "\' is consuming tuple/list arguments[" << i
|
||||
<< "]: " << args[i]->ToString();
|
||||
SetSequenceElementsUseFlags(args[i], true);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
|
||||
// To check tuple/list operations with a white list of Python primitive.
|
||||
CheckSequenceArgumentForCppPrimitive(prim_, args);
|
||||
|
||||
if (prims_to_skip_undetermined_infer.find(prim_->name()) == prims_to_skip_undetermined_infer.end()) {
|
||||
auto ret_abstract = AbstractEval(args);
|
||||
if (ret_abstract != nullptr) {
|
||||
|
@ -846,6 +851,9 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c
|
|||
}
|
||||
|
||||
EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
|
||||
// Consider all primitive implemented python infer() real use the tuple/list arguments.
|
||||
CheckSequenceArgumentForPythonPrimitive(prim_py_, args);
|
||||
|
||||
// Ensure input arguments are evaluated.
|
||||
auto ret_abstract = AbstractEval(args);
|
||||
if (ret_abstract != nullptr) {
|
||||
|
@ -875,9 +883,9 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, con
|
|||
prim_py_->EndRecordAddAttr();
|
||||
const auto &added_attrs = prim_py_->evaluate_added_attrs();
|
||||
MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output);
|
||||
auto res_spec = PyInferRes2Abstract(prim_py_, output);
|
||||
MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << ".";
|
||||
eval_result = std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
|
||||
auto res_abs = PyInferRes2Abstract(prim_py_, output);
|
||||
MS_LOG(DEBUG) << "Python InferTensor result abstract: " << res_abs->ToString();
|
||||
eval_result = std::make_shared<EvalResult>(res_abs, std::make_shared<AttrValueMap>(added_attrs));
|
||||
// Save result to global primitive eval cache.
|
||||
if (enable_global_cache) {
|
||||
eval_cache_->Put(prim_py_, std::move(input_attrs), args, eval_result);
|
||||
|
@ -886,13 +894,6 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, con
|
|||
// Save result to evaluator cache.
|
||||
evaluator_cache_mgr_->SetValue(args, eval_result);
|
||||
}
|
||||
// To check tuple/list operations with a white list of Python primitive.
|
||||
if (prims_use_sequence_elements.find(prim_py_->name()) != prims_use_sequence_elements.end()) {
|
||||
// Set all used flags of tuple as true.
|
||||
for (auto &arg : args) {
|
||||
SetSequenceElementsUseFlags(arg, true);
|
||||
}
|
||||
}
|
||||
return eval_result;
|
||||
}
|
||||
|
||||
|
|
|
@ -303,7 +303,6 @@ AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const Abst
|
|||
// Inputs: a tuple or list or dict.
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
auto arg = CheckArg<T>(op_name, args_spec_list, 0);
|
||||
SetSequenceElementsUseFlags(arg, true);
|
||||
return std::make_shared<AbstractScalar>(SizeToLong(arg->size()));
|
||||
}
|
||||
} // namespace abstract
|
||||
|
|
|
@ -115,8 +115,6 @@ AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const Primiti
|
|||
(void)std::transform(res.begin(), res.end(), std::back_inserter(elems), [](int64_t n) -> AbstractBasePtr {
|
||||
return std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(n), kInt64);
|
||||
});
|
||||
SetSequenceElementsUseFlags(xs, true);
|
||||
SetSequenceElementsUseFlags(ys, true);
|
||||
return std::make_shared<AbstractTuple>(elems);
|
||||
}
|
||||
|
||||
|
@ -137,8 +135,6 @@ AbstractBasePtr InferImplStack(const AnalysisEnginePtr &, const PrimitivePtr &pr
|
|||
arg = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
|
||||
tuple_len = arg->elements().size();
|
||||
tensor_base = CheckArg<AbstractTensor>(op_name, arg->elements(), 0);
|
||||
// For Stack(tuple), set all used flags of tuple as true.
|
||||
SetSequenceElementsUseFlags(args_spec_list[0], true);
|
||||
} else if (args_spec_list[0]->isa<AbstractTensor>()) {
|
||||
tuple_len = args_spec_list.size();
|
||||
tensor_base = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
|
@ -1369,8 +1365,6 @@ AbstractBasePtr InferImplDynamicStitch(const AnalysisEnginePtr &, const Primitiv
|
|||
min_shape[0] = 1;
|
||||
max_shape[0] = indices_total_size * EXPAND_MAX;
|
||||
}
|
||||
SetSequenceElementsUseFlags(input_tuple, true);
|
||||
SetSequenceElementsUseFlags(input_tuple_1, true);
|
||||
return std::make_shared<AbstractTensor>(infer_type,
|
||||
std::make_shared<abstract::Shape>(out_shape, min_shape, max_shape));
|
||||
}
|
||||
|
@ -1381,9 +1375,6 @@ AbstractBasePtr InferImplTensorCopySlices(const AnalysisEnginePtr &, const Primi
|
|||
constexpr auto kTensorCopySlicesInputNum = 5;
|
||||
CheckArgsSize(op_name, args_spec_list, kTensorCopySlicesInputNum);
|
||||
AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
for (size_t i = 2; i < args_spec_list.size(); ++i) {
|
||||
SetSequenceElementsUseFlags(args_spec_list[i], true);
|
||||
}
|
||||
return std::make_shared<AbstractTensor>(input->element(), input->shape());
|
||||
}
|
||||
} // namespace abstract
|
||||
|
|
|
@ -191,9 +191,6 @@ AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &p
|
|||
return args_spec_list[0];
|
||||
}
|
||||
|
||||
// For F.depend(x, MakeTuple()) or F.depend(x, tuple), set all used flags of tuple as true.
|
||||
SetSequenceElementsUseFlags(dependant_abstract, true);
|
||||
|
||||
auto depends = args_spec_list[0]->Broaden(); // Avoid eliminating the dependent node.
|
||||
if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR)) {
|
||||
// For scalar, need to set value to kAnyValue, because broaden scalar will not change the value.
|
||||
|
@ -210,12 +207,6 @@ AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitiveP
|
|||
MS_LOG(EXCEPTION) << primitive->name() << " input args size should be at least 1, but got 0";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||
|
||||
// For UpdateState(x, MakeTuple()) or UpdateState(x, tuple), set all used flags of tuple as true.
|
||||
for (size_t i = 1; i < args_spec_list.size(); i++) {
|
||||
SetSequenceElementsUseFlags(args_spec_list[i], true);
|
||||
}
|
||||
|
||||
return args_spec_list[0]->Broaden();
|
||||
}
|
||||
|
||||
|
@ -279,7 +270,6 @@ AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const Primitiv
|
|||
ret->set_indices(indices);
|
||||
ret->set_values(values);
|
||||
ret->set_dense_shape(dense_shape);
|
||||
SetSequenceElementsUseFlags(dense_shape, true);
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
@ -383,7 +373,6 @@ AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const Primi
|
|||
ret->set_indices(indices);
|
||||
ret->set_values(values);
|
||||
ret->set_dense_shape(dense_shape);
|
||||
SetSequenceElementsUseFlags(dense_shape, true);
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
@ -597,7 +586,6 @@ AbstractBasePtr InferImplMakeCSRTensor(const AnalysisEnginePtr &, const Primitiv
|
|||
ret->set_indices(indices);
|
||||
ret->set_values(values);
|
||||
ret->set_dense_shape(shape);
|
||||
SetSequenceElementsUseFlags(shape, true);
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
|
|
@ -99,9 +99,6 @@ AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitiveP
|
|||
}
|
||||
}
|
||||
|
||||
// For F.SwitchLayer(MakeTuple(), x) or F.depend(tuple, x), set all used flags of tuple as true.
|
||||
SetSequenceElementsUseFlags(branches_abs, true);
|
||||
|
||||
auto b = branches[0];
|
||||
// Return AbstractFuncUnion, otherwise the switch_layer will be replaced by branches[0]
|
||||
// which will cancel the out of bound checking for index
|
||||
|
|
|
@ -60,8 +60,6 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
auto key_string = GetValue<std::string>(keyPtr);
|
||||
(void)key_value.emplace_back(key_string, value_list[index]);
|
||||
}
|
||||
SetSequenceElementsUseFlags(keys, true);
|
||||
SetSequenceElementsUseFlags(values, true);
|
||||
return std::make_shared<AbstractDictionary>(key_value);
|
||||
}
|
||||
|
||||
|
@ -167,7 +165,6 @@ AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const Abstra
|
|||
constexpr int target_value_index = 2;
|
||||
elements[index_unsigned_value] = args_spec_list[target_value_index];
|
||||
MS_LOG(DEBUG) << "SetItem use flags, index: " << index_unsigned_value << ", for " << queue->ToString();
|
||||
SetSequenceElementsUseFlags(queue, index_unsigned_value, true);
|
||||
return std::make_shared<T>(elements, queue->sequence_nodes());
|
||||
}
|
||||
|
||||
|
|
|
@ -96,7 +96,6 @@ AbstractBasePtr AccumulateNV2Infer(const abstract::AnalysisEnginePtr &, const Pr
|
|||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
SetSequenceElementsUseFlags(input_args[0], true);
|
||||
auto infer_type = AccumulateNV2InferType(primitive, input_args);
|
||||
auto infer_shape = AccumulateNV2InferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
|
|
|
@ -98,10 +98,7 @@ AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
|
|||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto res = abstract::MakeAbstract(AddNInferShape(primitive, input_args), AddNInferType(primitive, input_args));
|
||||
// For AddN(MakeTuple()) or AddN(tuple), set all used flags of tuple as true.
|
||||
SetSequenceElementsUseFlags(input_args[0], true);
|
||||
return res;
|
||||
return abstract::MakeAbstract(AddNInferShape(primitive, input_args), AddNInferType(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(AddN, prim::kPrimAddN, AddNInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
|
|
|
@ -39,7 +39,6 @@ abstract::TupleShapePtr CTCLossV2InferShape(const PrimitivePtr &primitive,
|
|||
prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
SetSequenceElementsUseFlags(item, true);
|
||||
}
|
||||
auto log_probs_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||
auto targets_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
|
||||
|
|
|
@ -37,7 +37,6 @@ abstract::ShapePtr CTCLossV2GradInferShape(const PrimitivePtr &primitive,
|
|||
prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
SetSequenceElementsUseFlags(item, true);
|
||||
}
|
||||
auto log_probs_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
|
||||
auto log_probs_shape = log_probs_shape_map[kShape];
|
||||
|
|
|
@ -118,14 +118,11 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
} // namespace
|
||||
AbstractBasePtr DynamicResizeNearestNeighborInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
constexpr size_t size_index = 1;
|
||||
auto prim_name = primitive->name();
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("infer", SizeToLong(CheckAndConvertUtils::GetRemoveMonadAbsNum(input_args)),
|
||||
kEqual, input_num, prim_name);
|
||||
auto res = abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
|
||||
// Set all used flags of tuple as true.
|
||||
SetSequenceElementsUseFlags(input_args[size_index], true);
|
||||
return res;
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(DynamicResizeNearestNeighbor, prim::kPrimDynamicResizeNearestNeighbor,
|
||||
|
|
|
@ -240,8 +240,6 @@ AbstractBasePtr EinsumInfer(const abstract::AnalysisEnginePtr &, const Primitive
|
|||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
auto res =
|
||||
std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), InferShape(primitive, input_args));
|
||||
// For Einsum(tuple/list), set all used flags of tuple as true.
|
||||
SetSequenceElementsUseFlags(input_args[0], true);
|
||||
return res;
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Einsum, prim::kPrimEinsum, EinsumInfer, nullptr, true);
|
||||
|
|
|
@ -60,11 +60,8 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP
|
|||
|
||||
AbstractBasePtr AvgPool3DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
constexpr size_t origin_input_size_index = 0;
|
||||
auto res = std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
// Set all used flags of tuple as true.
|
||||
SetSequenceElementsUseFlags(input_args[origin_input_size_index], true);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
|
|
@ -241,8 +241,6 @@ AbstractBasePtr Conv2DBackpropFilterInfer(const abstract::AnalysisEnginePtr &, c
|
|||
}
|
||||
auto res = std::make_shared<abstract::AbstractTensor>(Conv2DBackpropFilterInferType(primitive, input_args),
|
||||
Conv2DBackpropFilterInferShape(primitive, input_args));
|
||||
// Set all used flags of tuple as true.
|
||||
SetSequenceElementsUseFlags(input_args[kFilterSizeIndex], true);
|
||||
return res;
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DBackpropFilter, prim::kPrimConv2DBackpropFilter, Conv2DBackpropFilterInfer, nullptr,
|
||||
|
|
|
@ -40,7 +40,6 @@ float DropoutGrad::get_keep_prob() const {
|
|||
AbstractBasePtr DropoutGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
constexpr size_t shape_index = 0;
|
||||
auto op_name = primitive->name();
|
||||
const int64_t input_num = 2;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, op_name);
|
||||
|
@ -55,8 +54,6 @@ AbstractBasePtr DropoutGradInfer(const abstract::AnalysisEnginePtr &, const Prim
|
|||
auto out_type = CheckAndConvertUtils::CheckTensorTypeValid("x", dy_type, valid_types, op_name);
|
||||
auto shape = CheckAndConvertUtils::GetTensorInputShape(op_name, input_args, dy_index);
|
||||
auto res = abstract::MakeAbstract(shape, out_type);
|
||||
// Set all used flags of tuple as true.
|
||||
SetSequenceElementsUseFlags(input_args[shape_index], true);
|
||||
return res;
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(DropoutGrad, prim::kPrimDropoutGrad, DropoutGradInfer, nullptr, true);
|
||||
|
|
|
@ -118,19 +118,10 @@ TypePtr StridedSliceGradInferType(const PrimitivePtr &primitive, const std::vect
|
|||
AbstractBasePtr StridedSliceGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
constexpr size_t shape_index = 1;
|
||||
constexpr size_t begin_index = 2;
|
||||
constexpr size_t end_index = 3;
|
||||
constexpr size_t stride_index = 4;
|
||||
constexpr int64_t input_num = 5;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
auto res = abstract::MakeAbstract(StridedSliceGradInferShape(primitive, input_args),
|
||||
StridedSliceGradInferType(primitive, input_args));
|
||||
// Set all used flags of tuple as true.
|
||||
SetSequenceElementsUseFlags(input_args[shape_index], true);
|
||||
SetSequenceElementsUseFlags(input_args[begin_index], true);
|
||||
SetSequenceElementsUseFlags(input_args[end_index], true);
|
||||
SetSequenceElementsUseFlags(input_args[stride_index], true);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
|
|
@ -105,7 +105,6 @@ AbstractBasePtr LpNormInfer(const abstract::AnalysisEnginePtr &, const Primitive
|
|||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
|
||||
auto infer_type = InferType(primitive, input_args);
|
||||
auto infer_shape = InferShape(primitive, input_args);
|
||||
SetSequenceElementsUseFlags(input_args[0], true);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(LpNorm, prim::kPrimLpNorm, LpNormInfer, nullptr, true);
|
||||
|
|
|
@ -182,16 +182,10 @@ TypePtr InferType(const PrimitivePtr &primitive) {
|
|||
} // namespace
|
||||
AbstractBasePtr NeighborExchangeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
constexpr size_t input_tuple_index = 0;
|
||||
Check(primitive, input_args);
|
||||
auto type = InferType(primitive);
|
||||
auto shape = InferShape(primitive);
|
||||
auto res = abstract::MakeAbstract(shape, type);
|
||||
// Set all used flags of tuple as true.
|
||||
if (!input_args.empty()) {
|
||||
SetSequenceElementsUseFlags(input_args[input_tuple_index], true);
|
||||
}
|
||||
return res;
|
||||
return abstract::MakeAbstract(shape, type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(NeighborExchange, prim::kPrimNeighborExchange, NeighborExchangeInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
|
|
|
@ -57,14 +57,10 @@ TypePtr OnesInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePt
|
|||
AbstractBasePtr OnesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
constexpr size_t shape_index = 0;
|
||||
const std::string op_name = primitive->name();
|
||||
const int64_t input_num = 2;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, op_name);
|
||||
auto res = abstract::MakeAbstract(OnesInferShape(primitive, input_args), OnesInferType(primitive, input_args));
|
||||
// Set all used flags of tuple as true.
|
||||
SetSequenceElementsUseFlags(input_args[shape_index], true);
|
||||
return res;
|
||||
return abstract::MakeAbstract(OnesInferShape(primitive, input_args), OnesInferType(primitive, input_args));
|
||||
}
|
||||
|
||||
ValuePtr OnesInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
|
|
|
@ -93,20 +93,13 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
AbstractBasePtr TransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
constexpr size_t input_perm_index = 1;
|
||||
// The second input is optional.
|
||||
constexpr size_t input_size1 = 1;
|
||||
constexpr size_t input_size2 = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("Transpose infer", SizeToLong(input_args.size()), kGreaterEqual, input_size1,
|
||||
primitive->name());
|
||||
auto type = InferType(primitive, input_args);
|
||||
auto shape = InferShape(primitive, input_args);
|
||||
auto res = abstract::MakeAbstract(shape, type);
|
||||
if (input_args.size() == input_size2) {
|
||||
SetSequenceElementsUseFlags(input_args[input_perm_index], true);
|
||||
}
|
||||
|
||||
return res;
|
||||
return abstract::MakeAbstract(shape, type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Transpose, prim::kPrimTranspose, TransposeInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
|
|
|
@ -59,11 +59,7 @@ TypePtr ZerosInferType(const PrimitivePtr &prim, const std::vector<AbstractBaseP
|
|||
AbstractBasePtr ZerosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
constexpr size_t shape_index = 0;
|
||||
auto res = abstract::MakeAbstract(ZerosInferShape(primitive, input_args), ZerosInferType(primitive, input_args));
|
||||
// Set all used flags of tuple as true.
|
||||
SetSequenceElementsUseFlags(input_args[shape_index], true);
|
||||
return res;
|
||||
return abstract::MakeAbstract(ZerosInferShape(primitive, input_args), ZerosInferType(primitive, input_args));
|
||||
}
|
||||
|
||||
ValuePtr ZerosInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
|
|
Loading…
Reference in New Issue