!49402 Optimize predicetoutmap and predictout check
Merge pull request !49402 from NaCN/pre_fix
This commit is contained in:
commit
dfbee4567e
|
@ -87,7 +87,6 @@ inline static PredictOutTypeMap out_type_prediction = {{"ActsULQ", kTupleTensor4
|
|||
{"CTCGreedyDecoder", kTupleTensor4},
|
||||
{"CTCLoss", kTupleTensor2},
|
||||
{"CTCLossV2", kTupleTensor2},
|
||||
{"CdistGrad", kAnyType},
|
||||
{"Coalesce", kTupleTensor3},
|
||||
{"ConcatOffset", kAnyType},
|
||||
{"CombinedNonMaxSuppression", kTupleTensor4},
|
||||
|
@ -177,7 +176,7 @@ inline static PredictOutTypeMap out_type_prediction = {{"ActsULQ", kTupleTensor4
|
|||
{"MaximumGradGrad", kTupleTensor3},
|
||||
{"Median", kTupleTensor2},
|
||||
{"MedianGrad", kTupleTensor2},
|
||||
{"Meshgrid", kAnyType},
|
||||
{"Meshgrid", kTuple},
|
||||
{"MinMaxUpdatePerChannel", kTupleTensor2},
|
||||
{"MinMaxUpdatePerLayer", kTupleTensor2},
|
||||
{"MinimumGrad", kTupleTensor2},
|
||||
|
@ -247,7 +246,7 @@ inline static PredictOutTypeMap out_type_prediction = {{"ActsULQ", kTupleTensor4
|
|||
{"SparseSparseMinimum", kTupleTensor2},
|
||||
{"SparseSplit", kTuple},
|
||||
{"SparseTensorToCSRSparseMatrix", kTupleTensor5},
|
||||
{"Split", kAnyType},
|
||||
{"Split", kTuple},
|
||||
{"SquareSumAll", kTupleTensor2},
|
||||
{"SquareSumV2", kTupleTensor2},
|
||||
{"Sspaddmm", kTupleTensor3},
|
||||
|
@ -262,8 +261,8 @@ inline static PredictOutTypeMap out_type_prediction = {{"ActsULQ", kTupleTensor4
|
|||
{"Unique", kTupleTensor2},
|
||||
{"UniqueConsecutive", kTupleTensor3},
|
||||
{"UniqueWithPad", kTupleTensor2},
|
||||
{"Unpack", kAnyType},
|
||||
{"Unstack", kAnyType},
|
||||
{"Unpack", kTuple},
|
||||
{"Unstack", kTuple},
|
||||
{"bit_and", kAnyType},
|
||||
{"bit_or", kAnyType},
|
||||
{"make_range", kAnyType},
|
||||
|
|
|
@ -106,7 +106,9 @@ const AnfNodePtr BatchNorm2BNInfer::Process(const FuncGraphPtr &graph, const Anf
|
|||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
auto bn_infer = CreateBNInfer(graph, cnode);
|
||||
TransferDependOrUpdateState(cnode, graph, bn_infer);
|
||||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
|
||||
auto kernel_graph = graph->cast<KernelGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
if (kernel_graph->is_from_single_op()) {
|
||||
const auto ori_inputs = cnode->inputs();
|
||||
if (ori_inputs.size() < kBatchNormInputNum) {
|
||||
MS_LOG(EXCEPTION) << "BatchNorm's inputs size is less than 5.";
|
||||
|
|
|
@ -181,8 +181,10 @@ py::object TensorNode::GetDtype() {
|
|||
}
|
||||
|
||||
bool TensorNode::SetAbstract(const AbstractBasePtr &abs) {
|
||||
if (!abs->isa<abstract::AbstractTensor>()) {
|
||||
return false;
|
||||
if (!abs->isa<abstract::AbstractTensor>() && !abs->isa<abstract::AbstractMapTensor>()) {
|
||||
if (!abs->isa<abstract::AbstractScalar>() || abs->BuildValue() != kAnyValue) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return StubNode::SetAbstract(abs);
|
||||
}
|
||||
|
@ -223,7 +225,7 @@ bool SequenceNode::SetAbstract(const AbstractBasePtr &abs) {
|
|||
void SequenceNode::SetValue(const ValuePtr &val) {
|
||||
auto seq_value = val->cast<ValueSequencePtr>();
|
||||
auto children = seq_value->value();
|
||||
for (size_t i = 0; i < elements_.size(); ++i) {
|
||||
for (size_t i = 0; i < children.size(); ++i) {
|
||||
elements_[i]->SetValue(children[i]);
|
||||
elements_[i]->SetTopNode(nullptr);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue