opt log about invalid primitive

This commit is contained in:
huanghui 2023-01-03 14:47:03 +08:00
parent b7661762b2
commit 8de4766f9e
2 changed files with 32 additions and 9 deletions

View File

@ -469,20 +469,23 @@ void AnalysisEngine::ClearEvaluatorCache() {
py::gil_scoped_acquire gil;
for (auto &element : evaluators_) {
EvaluatorPtr evaluator = element.second;
MS_EXCEPTION_IF_NULL(evaluator);
MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr());
if (evaluator == nullptr || evaluator->evaluator_cache_mgr() == nullptr) {
continue;
}
evaluator->evaluator_cache_mgr()->Clear();
}
for (auto &element : prim_constructors_) {
EvaluatorPtr evaluator = element.second;
MS_EXCEPTION_IF_NULL(evaluator);
MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr());
if (evaluator == nullptr || evaluator->evaluator_cache_mgr() == nullptr) {
continue;
}
evaluator->evaluator_cache_mgr()->Clear();
}
for (auto &element : prim_py_evaluators_) {
EvaluatorPtr evaluator = element.second;
MS_EXCEPTION_IF_NULL(evaluator);
MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr());
if (evaluator == nullptr || evaluator->evaluator_cache_mgr() == nullptr) {
continue;
}
evaluator->evaluator_cache_mgr()->Clear();
}
// Release exception to avoid hup at exit.
@ -582,7 +585,7 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PrimitiveAbs
if (is_new) {
iter->second = GetPrimEvaluator(primitive, shared_from_this());
if (iter->second == nullptr) {
MS_LOG(EXCEPTION) << "The evaluator of the primitive is not defined (" << primitive->name() << ").";
MS_LOG(EXCEPTION) << "Operator '" << primitive->name() << "' is invalid.";
}
}
return iter->second;

View File

@ -14,9 +14,9 @@
# ============================================================================
""" test_net_infer """
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore import Tensor, context, jit, ops
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
import mindspore.ops.operations as op
@ -124,3 +124,23 @@ def test_maybe_poly_func():
y_input = Tensor(np.array([1, 2]).astype(np.int32))
z_input = Tensor(np.array([[2, 2], [3, 3]]).astype(np.int32))
Net()(Tensor(np.array(1).astype(np.int32)), y_input, z_input)
def test_invalid_primitive():
"""
Feature: Inner primitive infer.
Description: Test invalid primitive.
Expectation: RuntimeError.
"""
context.set_context(mode=context.GRAPH_MODE)
invalid_prim = ops.Primitive("invalid_prim")
@jit
def func(x):
return invalid_prim(x)
a = Tensor([1])
with pytest.raises(RuntimeError) as ex:
func(a)
assert "Operator 'invalid_prim' is invalid." in str(
ex.value)