Add arguments checking for AddN primitive;

In JIT Fallback, getattr(interpreted_obj, 'attr') should work as interpred_obj.attr;
Support Return InterpretedNode value;
Not throw exception if getattr's default value is InterpretedNode;
This commit is contained in:
张清华 2023-03-01 16:36:44 +08:00
parent 03ac6b3a37
commit 46e9e37d1e
6 changed files with 36 additions and 31 deletions

View File

@ -33,6 +33,7 @@
#include "frontend/operator/composite/composite.h" #include "frontend/operator/composite/composite.h"
#include "ir/anf.h" #include "ir/anf.h"
#include "ir/value.h" #include "ir/value.h"
#include "pipeline/jit/fallback.h"
#include "pipeline/jit/parse/resolve.h" #include "pipeline/jit/parse/resolve.h"
#include "utils/hash_map.h" #include "utils/hash_map.h"
#include "utils/anf_utils.h" #include "utils/anf_utils.h"
@ -1059,10 +1060,19 @@ class CleanAfterOptARewriter : public BaseRewriter {
return make_dict_node; return make_dict_node;
} }
AnfNodePtr ConvertInterpretedObjectValue(const ValueNodePtr &node, const parse::InterpretedObjectPtr &value) {
// Convert InterpretedObject value node to PyExecute CNode.
return ConvertInterpretedObjectToPyExecute(root_graph_, value, node);
}
AnfNodePtr ConvertValueNode(const ValueNodePtr &value_node, const ValuePtr &value) override { AnfNodePtr ConvertValueNode(const ValueNodePtr &value_node, const ValuePtr &value) override {
const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0"); const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
if (value->isa<ValueDictionary>() && support_fallback_runtime) { if (support_fallback_runtime) {
return RebuildValueDict(value_node, value->cast<ValueDictionaryPtr>()); if (value->isa<ValueDictionary>()) {
return RebuildValueDict(value_node, value->cast<ValueDictionaryPtr>());
} else if (value->isa<parse::InterpretedObject>()) {
return ConvertInterpretedObjectValue(value_node, value->cast<parse::InterpretedObjectPtr>());
}
} }
bool need_convert = false; bool need_convert = false;
auto convert_value = ConvertValueSequenceToValueTuple(value, 0, &need_convert); auto convert_value = ConvertValueSequenceToValueTuple(value, 0, &need_convert);

View File

@ -1885,7 +1885,7 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
} }
constexpr auto max_args_size = 3; constexpr auto max_args_size = 3;
if (args_abs_list.size() == max_args_size) { if (!support_fallback_runtime && args_abs_list.size() == max_args_size) {
constexpr size_t default_index = 2; constexpr size_t default_index = 2;
auto default_args = args_abs_list[default_index]; auto default_args = args_abs_list[default_index];
if (default_args->isa<abstract::AbstractScalar>()) { if (default_args->isa<abstract::AbstractScalar>()) {

View File

@ -64,11 +64,16 @@ bool AddNDynShapeJoin(ShapeVector *shape1, const ShapeVector *shape2) {
} }
abstract::ShapePtr AddNInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr AddNInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
if (!input_args[0]->isa<abstract::AbstractTuple>() && !input_args[0]->isa<abstract::AbstractList>()) {
MS_EXCEPTION(TypeError) << "For '" << prim_name
<< "', the input data type must be list or tuple of tensors.But got:"
<< input_args[0]->ToString();
}
auto elements = input_args[0]->isa<abstract::AbstractTuple>() auto elements = input_args[0]->isa<abstract::AbstractTuple>()
? input_args[0]->cast<abstract::AbstractTuplePtr>()->elements() ? input_args[0]->cast<abstract::AbstractTuplePtr>()->elements()
: input_args[0]->cast<abstract::AbstractListPtr>()->elements(); : input_args[0]->cast<abstract::AbstractListPtr>()->elements();
(void)CheckAndConvertUtils::CheckInteger("input num", SizeToLong(elements.size()), kGreaterEqual, 1, (void)CheckAndConvertUtils::CheckInteger("input num", SizeToLong(elements.size()), kGreaterEqual, 1, prim_name);
primitive->name());
(void)primitive->AddAttr("N", MakeValue(SizeToLong(elements.size()))); (void)primitive->AddAttr("N", MakeValue(SizeToLong(elements.size())));
(void)primitive->AddAttr("n", MakeValue(SizeToLong(elements.size()))); (void)primitive->AddAttr("n", MakeValue(SizeToLong(elements.size())));
auto shape_0 = elements[0]->BuildShape(); auto shape_0 = elements[0]->BuildShape();
@ -91,21 +96,25 @@ abstract::ShapePtr AddNInferShape(const PrimitivePtr &primitive, const std::vect
} }
// Join input[i] with input[0] // Join input[i] with input[0]
if (!AddNDynShapeJoin(&output_shape, &shape_vec)) { if (!AddNDynShapeJoin(&output_shape, &shape_vec)) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', input shape must be same, but got shape of input[" MS_EXCEPTION(ValueError) << "For '" << prim_name << "', input shape must be same, but got shape of input[" << i
<< i << "]: " << shape->ToString() << ", shape of input[0]: " << shape_0->ToString() << "]: " << shape->ToString() << ", shape of input[0]: " << shape_0->ToString() << ".";
<< ".";
} }
} }
return std::make_shared<abstract::Shape>(output_shape); return std::make_shared<abstract::Shape>(output_shape);
} }
TypePtr AddNInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { TypePtr AddNInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = prim->name(); auto prim_name = primitive->name();
if (!input_args[0]->isa<abstract::AbstractTuple>() && !input_args[0]->isa<abstract::AbstractList>()) {
MS_EXCEPTION(TypeError) << "For '" << prim_name
<< "', the input data type must be list or tuple of tensors.But got:"
<< input_args[0]->ToString();
}
auto elements = input_args[0]->isa<abstract::AbstractTuple>() auto elements = input_args[0]->isa<abstract::AbstractTuple>()
? input_args[0]->cast<abstract::AbstractTuplePtr>()->elements() ? input_args[0]->cast<abstract::AbstractTuplePtr>()->elements()
: input_args[0]->cast<abstract::AbstractListPtr>()->elements(); : input_args[0]->cast<abstract::AbstractListPtr>()->elements();
(void)CheckAndConvertUtils::CheckInteger("concat element num", SizeToLong(elements.size()), kGreaterEqual, 1, (void)CheckAndConvertUtils::CheckInteger("concat element num", SizeToLong(elements.size()), kGreaterEqual, 1,
prim->name()); prim_name);
std::map<std::string, TypePtr> types; std::map<std::string, TypePtr> types;
(void)types.emplace("element_0", elements[0]->BuildType()); (void)types.emplace("element_0", elements[0]->BuildType());
for (size_t i = 0; i < elements.size(); ++i) { for (size_t i = 0; i < elements.size(); ++i) {

View File

@ -112,7 +112,7 @@ _unsupported_internal_type = (
) )
_hybrid_type = ( _hybrid_type = (
print, enumerate, zip, map, filter, abs, round, max, min, sum, hasattr, list, tuple print, enumerate, zip, map, filter, abs, round, max, min, sum, getattr, hasattr, list, tuple
) )
# Unsupported python builtin type in JIT Fallback. # Unsupported python builtin type in JIT Fallback.

View File

@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
""" test_parse_numpy """ """ test_parse_numpy """
import pytest
import numpy as np import numpy as np
from mindspore import nn, context, jit, Tensor from mindspore import nn, context, jit, Tensor
@ -44,11 +43,7 @@ def test_use_numpy_method():
return ret return ret
net = Net() net = Net()
# Not raise NotImplementedError('Mindspore not supports to use the numpy ...') any more, net()
# but raise RuntimeError('Should not use Python object in runtime...'), after support JIT Fallback.
with pytest.raises(RuntimeError) as err:
net()
assert "Should not use Python object in runtime" in str(err.value)
def test_use_numpy_module(): def test_use_numpy_module():
@ -61,11 +56,7 @@ def test_use_numpy_module():
return ret return ret
net = Net() net = Net()
# Not raise NotImplementedError('Mindspore not supports to use the numpy ...') any more, net()
# but raise RuntimeError('Should not use Python object in runtime...'), after support JIT Fallback.
with pytest.raises(RuntimeError) as err:
net()
assert "Should not use Python object in runtime" in str(err.value)
def test_np_calculate(): def test_np_calculate():

View File

@ -656,13 +656,10 @@ def test_getattr_numpy_array():
@jit @jit
def foo(): def foo():
x = np.array([1, 2, 3, 4]) x = np.array([1, 2, 3, 4])
# Should work as: return x.shape[0]
return getattr(x, "shape")[0] return getattr(x, "shape")[0]
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0' foo()
with pytest.raises(TypeError) as err:
foo() # Not throw error any more, should move to ST.
assert "Do not support to get attribute" in str(err.value)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
def test_getattr_numpy_array_2(): def test_getattr_numpy_array_2():
@ -677,9 +674,7 @@ def test_getattr_numpy_array_2():
x = 1 x = 1
return getattr(x, "shape", np.array([0, 1, 2, 3, 4])) return getattr(x, "shape", np.array([0, 1, 2, 3, 4]))
with pytest.raises(TypeError) as err: foo()
foo()
assert "For 'getattr', the third input 'default' can not" in str(err.value)
def test_getattr_for_fg_object(): def test_getattr_for_fg_object():