Add arguments checking for AddN primitive;
In JIT Fallback, getattr(interpreted_obj, 'attr') should work as interpred_obj.attr; Support Return InterpretedNode value; Not throw exception if getattr's default value is InterpretedNode;
This commit is contained in:
parent
03ac6b3a37
commit
46e9e37d1e
|
@ -33,6 +33,7 @@
|
||||||
#include "frontend/operator/composite/composite.h"
|
#include "frontend/operator/composite/composite.h"
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
#include "ir/value.h"
|
#include "ir/value.h"
|
||||||
|
#include "pipeline/jit/fallback.h"
|
||||||
#include "pipeline/jit/parse/resolve.h"
|
#include "pipeline/jit/parse/resolve.h"
|
||||||
#include "utils/hash_map.h"
|
#include "utils/hash_map.h"
|
||||||
#include "utils/anf_utils.h"
|
#include "utils/anf_utils.h"
|
||||||
|
@ -1059,10 +1060,19 @@ class CleanAfterOptARewriter : public BaseRewriter {
|
||||||
return make_dict_node;
|
return make_dict_node;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AnfNodePtr ConvertInterpretedObjectValue(const ValueNodePtr &node, const parse::InterpretedObjectPtr &value) {
|
||||||
|
// Convert InterpretedObject value node to PyExecute CNode.
|
||||||
|
return ConvertInterpretedObjectToPyExecute(root_graph_, value, node);
|
||||||
|
}
|
||||||
|
|
||||||
AnfNodePtr ConvertValueNode(const ValueNodePtr &value_node, const ValuePtr &value) override {
|
AnfNodePtr ConvertValueNode(const ValueNodePtr &value_node, const ValuePtr &value) override {
|
||||||
const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
|
const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
|
||||||
if (value->isa<ValueDictionary>() && support_fallback_runtime) {
|
if (support_fallback_runtime) {
|
||||||
return RebuildValueDict(value_node, value->cast<ValueDictionaryPtr>());
|
if (value->isa<ValueDictionary>()) {
|
||||||
|
return RebuildValueDict(value_node, value->cast<ValueDictionaryPtr>());
|
||||||
|
} else if (value->isa<parse::InterpretedObject>()) {
|
||||||
|
return ConvertInterpretedObjectValue(value_node, value->cast<parse::InterpretedObjectPtr>());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
bool need_convert = false;
|
bool need_convert = false;
|
||||||
auto convert_value = ConvertValueSequenceToValueTuple(value, 0, &need_convert);
|
auto convert_value = ConvertValueSequenceToValueTuple(value, 0, &need_convert);
|
||||||
|
|
|
@ -1885,7 +1885,7 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
|
||||||
}
|
}
|
||||||
|
|
||||||
constexpr auto max_args_size = 3;
|
constexpr auto max_args_size = 3;
|
||||||
if (args_abs_list.size() == max_args_size) {
|
if (!support_fallback_runtime && args_abs_list.size() == max_args_size) {
|
||||||
constexpr size_t default_index = 2;
|
constexpr size_t default_index = 2;
|
||||||
auto default_args = args_abs_list[default_index];
|
auto default_args = args_abs_list[default_index];
|
||||||
if (default_args->isa<abstract::AbstractScalar>()) {
|
if (default_args->isa<abstract::AbstractScalar>()) {
|
||||||
|
|
|
@ -64,11 +64,16 @@ bool AddNDynShapeJoin(ShapeVector *shape1, const ShapeVector *shape2) {
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract::ShapePtr AddNInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
abstract::ShapePtr AddNInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
auto prim_name = primitive->name();
|
||||||
|
if (!input_args[0]->isa<abstract::AbstractTuple>() && !input_args[0]->isa<abstract::AbstractList>()) {
|
||||||
|
MS_EXCEPTION(TypeError) << "For '" << prim_name
|
||||||
|
<< "', the input data type must be list or tuple of tensors.But got:"
|
||||||
|
<< input_args[0]->ToString();
|
||||||
|
}
|
||||||
auto elements = input_args[0]->isa<abstract::AbstractTuple>()
|
auto elements = input_args[0]->isa<abstract::AbstractTuple>()
|
||||||
? input_args[0]->cast<abstract::AbstractTuplePtr>()->elements()
|
? input_args[0]->cast<abstract::AbstractTuplePtr>()->elements()
|
||||||
: input_args[0]->cast<abstract::AbstractListPtr>()->elements();
|
: input_args[0]->cast<abstract::AbstractListPtr>()->elements();
|
||||||
(void)CheckAndConvertUtils::CheckInteger("input num", SizeToLong(elements.size()), kGreaterEqual, 1,
|
(void)CheckAndConvertUtils::CheckInteger("input num", SizeToLong(elements.size()), kGreaterEqual, 1, prim_name);
|
||||||
primitive->name());
|
|
||||||
(void)primitive->AddAttr("N", MakeValue(SizeToLong(elements.size())));
|
(void)primitive->AddAttr("N", MakeValue(SizeToLong(elements.size())));
|
||||||
(void)primitive->AddAttr("n", MakeValue(SizeToLong(elements.size())));
|
(void)primitive->AddAttr("n", MakeValue(SizeToLong(elements.size())));
|
||||||
auto shape_0 = elements[0]->BuildShape();
|
auto shape_0 = elements[0]->BuildShape();
|
||||||
|
@ -91,21 +96,25 @@ abstract::ShapePtr AddNInferShape(const PrimitivePtr &primitive, const std::vect
|
||||||
}
|
}
|
||||||
// Join input[i] with input[0]
|
// Join input[i] with input[0]
|
||||||
if (!AddNDynShapeJoin(&output_shape, &shape_vec)) {
|
if (!AddNDynShapeJoin(&output_shape, &shape_vec)) {
|
||||||
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', input shape must be same, but got shape of input["
|
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', input shape must be same, but got shape of input[" << i
|
||||||
<< i << "]: " << shape->ToString() << ", shape of input[0]: " << shape_0->ToString()
|
<< "]: " << shape->ToString() << ", shape of input[0]: " << shape_0->ToString() << ".";
|
||||||
<< ".";
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return std::make_shared<abstract::Shape>(output_shape);
|
return std::make_shared<abstract::Shape>(output_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
TypePtr AddNInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
TypePtr AddNInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
auto prim_name = prim->name();
|
auto prim_name = primitive->name();
|
||||||
|
if (!input_args[0]->isa<abstract::AbstractTuple>() && !input_args[0]->isa<abstract::AbstractList>()) {
|
||||||
|
MS_EXCEPTION(TypeError) << "For '" << prim_name
|
||||||
|
<< "', the input data type must be list or tuple of tensors.But got:"
|
||||||
|
<< input_args[0]->ToString();
|
||||||
|
}
|
||||||
auto elements = input_args[0]->isa<abstract::AbstractTuple>()
|
auto elements = input_args[0]->isa<abstract::AbstractTuple>()
|
||||||
? input_args[0]->cast<abstract::AbstractTuplePtr>()->elements()
|
? input_args[0]->cast<abstract::AbstractTuplePtr>()->elements()
|
||||||
: input_args[0]->cast<abstract::AbstractListPtr>()->elements();
|
: input_args[0]->cast<abstract::AbstractListPtr>()->elements();
|
||||||
(void)CheckAndConvertUtils::CheckInteger("concat element num", SizeToLong(elements.size()), kGreaterEqual, 1,
|
(void)CheckAndConvertUtils::CheckInteger("concat element num", SizeToLong(elements.size()), kGreaterEqual, 1,
|
||||||
prim->name());
|
prim_name);
|
||||||
std::map<std::string, TypePtr> types;
|
std::map<std::string, TypePtr> types;
|
||||||
(void)types.emplace("element_0", elements[0]->BuildType());
|
(void)types.emplace("element_0", elements[0]->BuildType());
|
||||||
for (size_t i = 0; i < elements.size(); ++i) {
|
for (size_t i = 0; i < elements.size(); ++i) {
|
||||||
|
|
|
@ -112,7 +112,7 @@ _unsupported_internal_type = (
|
||||||
)
|
)
|
||||||
|
|
||||||
_hybrid_type = (
|
_hybrid_type = (
|
||||||
print, enumerate, zip, map, filter, abs, round, max, min, sum, hasattr, list, tuple
|
print, enumerate, zip, map, filter, abs, round, max, min, sum, getattr, hasattr, list, tuple
|
||||||
)
|
)
|
||||||
|
|
||||||
# Unsupported python builtin type in JIT Fallback.
|
# Unsupported python builtin type in JIT Fallback.
|
||||||
|
|
|
@ -13,7 +13,6 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
""" test_parse_numpy """
|
""" test_parse_numpy """
|
||||||
import pytest
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mindspore import nn, context, jit, Tensor
|
from mindspore import nn, context, jit, Tensor
|
||||||
|
|
||||||
|
@ -44,11 +43,7 @@ def test_use_numpy_method():
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
net = Net()
|
net = Net()
|
||||||
# Not raise NotImplementedError('Mindspore not supports to use the numpy ...') any more,
|
net()
|
||||||
# but raise RuntimeError('Should not use Python object in runtime...'), after support JIT Fallback.
|
|
||||||
with pytest.raises(RuntimeError) as err:
|
|
||||||
net()
|
|
||||||
assert "Should not use Python object in runtime" in str(err.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_use_numpy_module():
|
def test_use_numpy_module():
|
||||||
|
@ -61,11 +56,7 @@ def test_use_numpy_module():
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
net = Net()
|
net = Net()
|
||||||
# Not raise NotImplementedError('Mindspore not supports to use the numpy ...') any more,
|
net()
|
||||||
# but raise RuntimeError('Should not use Python object in runtime...'), after support JIT Fallback.
|
|
||||||
with pytest.raises(RuntimeError) as err:
|
|
||||||
net()
|
|
||||||
assert "Should not use Python object in runtime" in str(err.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_np_calculate():
|
def test_np_calculate():
|
||||||
|
|
|
@ -656,13 +656,10 @@ def test_getattr_numpy_array():
|
||||||
@jit
|
@jit
|
||||||
def foo():
|
def foo():
|
||||||
x = np.array([1, 2, 3, 4])
|
x = np.array([1, 2, 3, 4])
|
||||||
|
# Should work as: return x.shape[0]
|
||||||
return getattr(x, "shape")[0]
|
return getattr(x, "shape")[0]
|
||||||
|
|
||||||
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
|
foo()
|
||||||
with pytest.raises(TypeError) as err:
|
|
||||||
foo() # Not throw error any more, should move to ST.
|
|
||||||
assert "Do not support to get attribute" in str(err.value)
|
|
||||||
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
|
|
||||||
|
|
||||||
|
|
||||||
def test_getattr_numpy_array_2():
|
def test_getattr_numpy_array_2():
|
||||||
|
@ -677,9 +674,7 @@ def test_getattr_numpy_array_2():
|
||||||
x = 1
|
x = 1
|
||||||
return getattr(x, "shape", np.array([0, 1, 2, 3, 4]))
|
return getattr(x, "shape", np.array([0, 1, 2, 3, 4]))
|
||||||
|
|
||||||
with pytest.raises(TypeError) as err:
|
foo()
|
||||||
foo()
|
|
||||||
assert "For 'getattr', the third input 'default' can not" in str(err.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_getattr_for_fg_object():
|
def test_getattr_for_fg_object():
|
||||||
|
|
Loading…
Reference in New Issue