!49402 Optimize predicetoutmap and predictout check

Merge pull request !49402 from NaCN/pre_fix
This commit is contained in:
i-robot 2023-03-02 01:55:43 +00:00 committed by Gitee
commit dfbee4567e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 12 additions and 9 deletions

View File

@ -87,7 +87,6 @@ inline static PredictOutTypeMap out_type_prediction = {{"ActsULQ", kTupleTensor4
{"CTCGreedyDecoder", kTupleTensor4},
{"CTCLoss", kTupleTensor2},
{"CTCLossV2", kTupleTensor2},
{"CdistGrad", kAnyType},
{"Coalesce", kTupleTensor3},
{"ConcatOffset", kAnyType},
{"CombinedNonMaxSuppression", kTupleTensor4},
@ -177,7 +176,7 @@ inline static PredictOutTypeMap out_type_prediction = {{"ActsULQ", kTupleTensor4
{"MaximumGradGrad", kTupleTensor3},
{"Median", kTupleTensor2},
{"MedianGrad", kTupleTensor2},
{"Meshgrid", kAnyType},
{"Meshgrid", kTuple},
{"MinMaxUpdatePerChannel", kTupleTensor2},
{"MinMaxUpdatePerLayer", kTupleTensor2},
{"MinimumGrad", kTupleTensor2},
@ -247,7 +246,7 @@ inline static PredictOutTypeMap out_type_prediction = {{"ActsULQ", kTupleTensor4
{"SparseSparseMinimum", kTupleTensor2},
{"SparseSplit", kTuple},
{"SparseTensorToCSRSparseMatrix", kTupleTensor5},
{"Split", kAnyType},
{"Split", kTuple},
{"SquareSumAll", kTupleTensor2},
{"SquareSumV2", kTupleTensor2},
{"Sspaddmm", kTupleTensor3},
@ -262,8 +261,8 @@ inline static PredictOutTypeMap out_type_prediction = {{"ActsULQ", kTupleTensor4
{"Unique", kTupleTensor2},
{"UniqueConsecutive", kTupleTensor3},
{"UniqueWithPad", kTupleTensor2},
{"Unpack", kAnyType},
{"Unstack", kAnyType},
{"Unpack", kTuple},
{"Unstack", kTuple},
{"bit_and", kAnyType},
{"bit_or", kAnyType},
{"make_range", kAnyType},

View File

@ -106,7 +106,9 @@ const AnfNodePtr BatchNorm2BNInfer::Process(const FuncGraphPtr &graph, const Anf
MS_EXCEPTION_IF_NULL(ms_context);
auto bn_infer = CreateBNInfer(graph, cnode);
TransferDependOrUpdateState(cnode, graph, bn_infer);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
auto kernel_graph = graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
if (kernel_graph->is_from_single_op()) {
const auto ori_inputs = cnode->inputs();
if (ori_inputs.size() < kBatchNormInputNum) {
MS_LOG(EXCEPTION) << "BatchNorm's inputs size is less than 5.";

View File

@ -181,8 +181,10 @@ py::object TensorNode::GetDtype() {
}
bool TensorNode::SetAbstract(const AbstractBasePtr &abs) {
if (!abs->isa<abstract::AbstractTensor>()) {
return false;
if (!abs->isa<abstract::AbstractTensor>() && !abs->isa<abstract::AbstractMapTensor>()) {
if (!abs->isa<abstract::AbstractScalar>() || abs->BuildValue() != kAnyValue) {
return false;
}
}
return StubNode::SetAbstract(abs);
}
@ -223,7 +225,7 @@ bool SequenceNode::SetAbstract(const AbstractBasePtr &abs) {
void SequenceNode::SetValue(const ValuePtr &val) {
auto seq_value = val->cast<ValueSequencePtr>();
auto children = seq_value->value();
for (size_t i = 0; i < elements_.size(); ++i) {
for (size_t i = 0; i < children.size(); ++i) {
elements_[i]->SetValue(children[i]);
elements_[i]->SetTopNode(nullptr);
}