Add arguments checking for AddN primitive;

In JIT Fallback, getattr(interpreted_obj, 'attr') should work as interpred_obj.attr;
Support Return InterpretedNode value;
Not throw exception if getattr's default value is InterpretedNode;
This commit is contained in:
张清华 2023-03-01 16:36:44 +08:00
parent 03ac6b3a37
commit 46e9e37d1e
6 changed files with 36 additions and 31 deletions

View File

@ -33,6 +33,7 @@
#include "frontend/operator/composite/composite.h"
#include "ir/anf.h"
#include "ir/value.h"
#include "pipeline/jit/fallback.h"
#include "pipeline/jit/parse/resolve.h"
#include "utils/hash_map.h"
#include "utils/anf_utils.h"
@ -1059,10 +1060,19 @@ class CleanAfterOptARewriter : public BaseRewriter {
return make_dict_node;
}
AnfNodePtr ConvertInterpretedObjectValue(const ValueNodePtr &node, const parse::InterpretedObjectPtr &value) {
// Convert InterpretedObject value node to PyExecute CNode.
return ConvertInterpretedObjectToPyExecute(root_graph_, value, node);
}
AnfNodePtr ConvertValueNode(const ValueNodePtr &value_node, const ValuePtr &value) override {
const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
if (value->isa<ValueDictionary>() && support_fallback_runtime) {
return RebuildValueDict(value_node, value->cast<ValueDictionaryPtr>());
if (support_fallback_runtime) {
if (value->isa<ValueDictionary>()) {
return RebuildValueDict(value_node, value->cast<ValueDictionaryPtr>());
} else if (value->isa<parse::InterpretedObject>()) {
return ConvertInterpretedObjectValue(value_node, value->cast<parse::InterpretedObjectPtr>());
}
}
bool need_convert = false;
auto convert_value = ConvertValueSequenceToValueTuple(value, 0, &need_convert);

View File

@ -1885,7 +1885,7 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
}
constexpr auto max_args_size = 3;
if (args_abs_list.size() == max_args_size) {
if (!support_fallback_runtime && args_abs_list.size() == max_args_size) {
constexpr size_t default_index = 2;
auto default_args = args_abs_list[default_index];
if (default_args->isa<abstract::AbstractScalar>()) {

View File

@ -64,11 +64,16 @@ bool AddNDynShapeJoin(ShapeVector *shape1, const ShapeVector *shape2) {
}
abstract::ShapePtr AddNInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
if (!input_args[0]->isa<abstract::AbstractTuple>() && !input_args[0]->isa<abstract::AbstractList>()) {
MS_EXCEPTION(TypeError) << "For '" << prim_name
<< "', the input data type must be list or tuple of tensors.But got:"
<< input_args[0]->ToString();
}
auto elements = input_args[0]->isa<abstract::AbstractTuple>()
? input_args[0]->cast<abstract::AbstractTuplePtr>()->elements()
: input_args[0]->cast<abstract::AbstractListPtr>()->elements();
(void)CheckAndConvertUtils::CheckInteger("input num", SizeToLong(elements.size()), kGreaterEqual, 1,
primitive->name());
(void)CheckAndConvertUtils::CheckInteger("input num", SizeToLong(elements.size()), kGreaterEqual, 1, prim_name);
(void)primitive->AddAttr("N", MakeValue(SizeToLong(elements.size())));
(void)primitive->AddAttr("n", MakeValue(SizeToLong(elements.size())));
auto shape_0 = elements[0]->BuildShape();
@ -91,21 +96,25 @@ abstract::ShapePtr AddNInferShape(const PrimitivePtr &primitive, const std::vect
}
// Join input[i] with input[0]
if (!AddNDynShapeJoin(&output_shape, &shape_vec)) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', input shape must be same, but got shape of input["
<< i << "]: " << shape->ToString() << ", shape of input[0]: " << shape_0->ToString()
<< ".";
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', input shape must be same, but got shape of input[" << i
<< "]: " << shape->ToString() << ", shape of input[0]: " << shape_0->ToString() << ".";
}
}
return std::make_shared<abstract::Shape>(output_shape);
}
TypePtr AddNInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = prim->name();
TypePtr AddNInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
if (!input_args[0]->isa<abstract::AbstractTuple>() && !input_args[0]->isa<abstract::AbstractList>()) {
MS_EXCEPTION(TypeError) << "For '" << prim_name
<< "', the input data type must be list or tuple of tensors.But got:"
<< input_args[0]->ToString();
}
auto elements = input_args[0]->isa<abstract::AbstractTuple>()
? input_args[0]->cast<abstract::AbstractTuplePtr>()->elements()
: input_args[0]->cast<abstract::AbstractListPtr>()->elements();
(void)CheckAndConvertUtils::CheckInteger("concat element num", SizeToLong(elements.size()), kGreaterEqual, 1,
prim->name());
prim_name);
std::map<std::string, TypePtr> types;
(void)types.emplace("element_0", elements[0]->BuildType());
for (size_t i = 0; i < elements.size(); ++i) {

View File

@ -112,7 +112,7 @@ _unsupported_internal_type = (
)
_hybrid_type = (
print, enumerate, zip, map, filter, abs, round, max, min, sum, hasattr, list, tuple
print, enumerate, zip, map, filter, abs, round, max, min, sum, getattr, hasattr, list, tuple
)
# Unsupported python builtin type in JIT Fallback.

View File

@ -13,7 +13,6 @@
# limitations under the License.
# ============================================================================
""" test_parse_numpy """
import pytest
import numpy as np
from mindspore import nn, context, jit, Tensor
@ -44,11 +43,7 @@ def test_use_numpy_method():
return ret
net = Net()
# Not raise NotImplementedError('Mindspore not supports to use the numpy ...') any more,
# but raise RuntimeError('Should not use Python object in runtime...'), after support JIT Fallback.
with pytest.raises(RuntimeError) as err:
net()
assert "Should not use Python object in runtime" in str(err.value)
net()
def test_use_numpy_module():
@ -61,11 +56,7 @@ def test_use_numpy_module():
return ret
net = Net()
# Not raise NotImplementedError('Mindspore not supports to use the numpy ...') any more,
# but raise RuntimeError('Should not use Python object in runtime...'), after support JIT Fallback.
with pytest.raises(RuntimeError) as err:
net()
assert "Should not use Python object in runtime" in str(err.value)
net()
def test_np_calculate():

View File

@ -656,13 +656,10 @@ def test_getattr_numpy_array():
@jit
def foo():
x = np.array([1, 2, 3, 4])
# Should work as: return x.shape[0]
return getattr(x, "shape")[0]
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
with pytest.raises(TypeError) as err:
foo() # Not throw error any more, should move to ST.
assert "Do not support to get attribute" in str(err.value)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
foo()
def test_getattr_numpy_array_2():
@ -677,9 +674,7 @@ def test_getattr_numpy_array_2():
x = 1
return getattr(x, "shape", np.array([0, 1, 2, 3, 4]))
with pytest.raises(TypeError) as err:
foo()
assert "For 'getattr', the third input 'default' can not" in str(err.value)
foo()
def test_getattr_for_fg_object():