forked from mindspore-Ecosystem/mindspore
!11921 do not broaden when arg is not tensor
From: @zhangbuxue Reviewed-by: Signed-off-by:
This commit is contained in:
commit
cbfba95ad0
|
@ -544,7 +544,7 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (CanSpecializeNode(func)) {
|
if (CanSpecializeNode(func)) {
|
||||||
// for primitive node , we build the primitive node with infered attributes in the first pass
|
// for primitive node , we build the primitive node with inferred attributes in the first pass
|
||||||
// so we do not build replaced node again here in second pass
|
// so we do not build replaced node again here in second pass
|
||||||
if (IsValueNode<Primitive>(func)) {
|
if (IsValueNode<Primitive>(func)) {
|
||||||
new_inputs[0] = func;
|
new_inputs[0] = func;
|
||||||
|
@ -666,14 +666,14 @@ AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin
|
||||||
|
|
||||||
AbstractFunctionPtr abs = dyn_cast<AbstractFunction>(ival);
|
AbstractFunctionPtr abs = dyn_cast<AbstractFunction>(ival);
|
||||||
if (abs != nullptr) {
|
if (abs != nullptr) {
|
||||||
// Cannot build a determinstic ValueNode if there are multiple possible AbstractFunction.
|
// Cannot build a deterministic ValueNode if there are multiple possible AbstractFunction.
|
||||||
if (abs->isa<AbstractFuncUnion>()) {
|
if (abs->isa<AbstractFuncUnion>()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
ValuePtr value = nullptr;
|
ValuePtr value = nullptr;
|
||||||
if (abs->isa<PrimitiveAbstractClosure>()) {
|
if (abs->isa<PrimitiveAbstractClosure>()) {
|
||||||
auto real_fn = dyn_cast<PrimitiveAbstractClosure>(abs);
|
auto real_fn = dyn_cast<PrimitiveAbstractClosure>(abs);
|
||||||
// for primitive, check if the attribute is the same with cnode infererd attribute ,if not, clone a new one
|
// for primitive, check if the attribute is the same with cnode inferred attribute, if not, clone a new one
|
||||||
if (attrs != nullptr) {
|
if (attrs != nullptr) {
|
||||||
value = BuildPrimtiveValueWithAttributes(real_fn->prim(), attrs);
|
value = BuildPrimtiveValueWithAttributes(real_fn->prim(), attrs);
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -88,7 +88,7 @@ std::string AbstractBase::ToString() const {
|
||||||
return buffer.str();
|
return buffer.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const { return AbstractBase::Broaden(config); }
|
AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const { return Clone(); }
|
||||||
|
|
||||||
AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
|
AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
|
||||||
MS_EXCEPTION_IF_NULL(other);
|
MS_EXCEPTION_IF_NULL(other);
|
||||||
|
|
|
@ -102,7 +102,7 @@ AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePt
|
||||||
ValuePtr name_value = prim->GetAttr("tag");
|
ValuePtr name_value = prim->GetAttr("tag");
|
||||||
auto name = name_value->cast<StringImmPtr>();
|
auto name = name_value->cast<StringImmPtr>();
|
||||||
if (name == nullptr) {
|
if (name == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "MakeRefKey attr tag sould be a String " << name_value->ToString() << ".";
|
MS_LOG(EXCEPTION) << "MakeRefKey attr tag should be a String " << name_value->ToString() << ".";
|
||||||
}
|
}
|
||||||
auto refkey = std::make_shared<RefKey>(name->value());
|
auto refkey = std::make_shared<RefKey>(name->value());
|
||||||
if (refkey == nullptr) {
|
if (refkey == nullptr) {
|
||||||
|
@ -168,6 +168,9 @@ AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &p
|
||||||
MS_LOG(EXCEPTION) << primitive->name() << " input args size should be at lest 1, but got 0";
|
MS_LOG(EXCEPTION) << primitive->name() << " input args size should be at lest 1, but got 0";
|
||||||
}
|
}
|
||||||
auto depends = args_spec_list[0]->Broaden();
|
auto depends = args_spec_list[0]->Broaden();
|
||||||
|
if (depends->isa<AbstractScalar>()) {
|
||||||
|
depends->set_value(kAnyValue);
|
||||||
|
}
|
||||||
return depends;
|
return depends;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -182,7 +185,7 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv
|
||||||
auto src_size = arg_src->cast<AbstractTuplePtr>()->size();
|
auto src_size = arg_src->cast<AbstractTuplePtr>()->size();
|
||||||
auto dst_size = arg_src->cast<AbstractTuplePtr>()->size();
|
auto dst_size = arg_src->cast<AbstractTuplePtr>()->size();
|
||||||
if (src_size > 1 && dst_size > 1) {
|
if (src_size > 1 && dst_size > 1) {
|
||||||
MS_LOG(EXCEPTION) << "Control depend can not setup operator dependcy relationship from tuple from tuple";
|
MS_LOG(EXCEPTION) << "Control depend can not setup operator dependency relationship from tuple from tuple";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return std::make_shared<AbstractScalar>(kAnyValue, kBool);
|
return std::make_shared<AbstractScalar>(kAnyValue, kBool);
|
||||||
|
@ -505,7 +508,7 @@ AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePt
|
||||||
auto axis = primitive->GetAttr("axis");
|
auto axis = primitive->GetAttr("axis");
|
||||||
auto value = GetValue<int64_t>(axis);
|
auto value = GetValue<int64_t>(axis);
|
||||||
if (value < -(SizeToInt(x_shape.size()) + 1) || value > SizeToInt(x_shape.size())) {
|
if (value < -(SizeToInt(x_shape.size()) + 1) || value > SizeToInt(x_shape.size())) {
|
||||||
MS_LOG(EXCEPTION) << " axis value shoud be in range [-intput_x.dim-1,input_x.dim], but axis value is" << value
|
MS_LOG(EXCEPTION) << " axis value should be in range [-input_x.dim-1,input_x.dim], but axis value is" << value
|
||||||
<< " and input_x.dim is" << x_shape.size();
|
<< " and input_x.dim is" << x_shape.size();
|
||||||
}
|
}
|
||||||
if (value < 0) {
|
if (value < 0) {
|
||||||
|
|
|
@ -32,26 +32,18 @@ TEST_F(TestUtils, test_join) {
|
||||||
AbstractBasePtr abs_s1 = FromValue(static_cast<int64_t>(1), false);
|
AbstractBasePtr abs_s1 = FromValue(static_cast<int64_t>(1), false);
|
||||||
AbstractBasePtr abs_s2 = FromValue(static_cast<int64_t>(2), false);
|
AbstractBasePtr abs_s2 = FromValue(static_cast<int64_t>(2), false);
|
||||||
AbstractBasePtr abs_s_anything = FromValue(static_cast<int64_t>(2), true);
|
AbstractBasePtr abs_s_anything = FromValue(static_cast<int64_t>(2), true);
|
||||||
|
abs_s_anything->set_value(kAnyValue);
|
||||||
|
|
||||||
AbstractBasePtr res_s1 = abs_s1->Join(abs_s2);
|
AbstractBasePtr res_s1 = abs_s1->Join(abs_s2);
|
||||||
ASSERT_EQ(*res_s1, *abs_s_anything);
|
ASSERT_EQ(*res_s1, *abs_s_anything);
|
||||||
|
|
||||||
// AbstractTuple join;
|
|
||||||
std::vector<int64_t> list1 = {1, 2, 3, 4, 5};
|
|
||||||
std::vector<int64_t> list2 = {5, 4, 3, 2, 1};
|
|
||||||
AbstractBasePtr abs_t1 = FromValue(list1, true);
|
|
||||||
AbstractBasePtr abs_t2 = FromValue(list2, true);
|
|
||||||
|
|
||||||
AbstractBasePtr res_t1 = abs_t1->Join(abs_t2);
|
|
||||||
ASSERT_EQ(res_t1, abs_t1);
|
|
||||||
|
|
||||||
abs_s1 = FromValue(static_cast<int64_t>(1), false);
|
abs_s1 = FromValue(static_cast<int64_t>(1), false);
|
||||||
|
|
||||||
AbstractBasePtr t1 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything}));
|
AbstractBasePtr t1 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything}));
|
||||||
AbstractBasePtr t2 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything}));
|
AbstractBasePtr t2 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything}));
|
||||||
AbstractBasePtr t3 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s_anything, abs_s_anything}));
|
AbstractBasePtr t3 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s_anything, abs_s_anything}));
|
||||||
|
|
||||||
res_t1 = t1->Join(t2);
|
AbstractBasePtr res_t1 = t1->Join(t2);
|
||||||
ASSERT_EQ(res_t1, t1);
|
ASSERT_EQ(res_t1, t1);
|
||||||
|
|
||||||
res_t1 = t1->Join(t3);
|
res_t1 = t1->Join(t3);
|
||||||
|
|
|
@ -111,8 +111,11 @@ TEST_F(TestOptLib, test_inline) {
|
||||||
// add infer and renormalize
|
// add infer and renormalize
|
||||||
std::shared_ptr<mindspore::pipeline::Resource> res = std::make_shared<mindspore::pipeline::Resource>();
|
std::shared_ptr<mindspore::pipeline::Resource> res = std::make_shared<mindspore::pipeline::Resource>();
|
||||||
AbstractBasePtrList args_spec_list;
|
AbstractBasePtrList args_spec_list;
|
||||||
AbstractBasePtr abstract_v1 = abstract::FromValue(static_cast<int64_t>(1), true);
|
tensor::TensorPtr x_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{2, 3});
|
||||||
AbstractBasePtr abstract_v2 = abstract::FromValue(static_cast<int64_t>(2), true);
|
tensor::TensorPtr y_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{2, 3});
|
||||||
|
|
||||||
|
AbstractBasePtr abstract_v1 = abstract::FromValue(x_tensor, true);
|
||||||
|
AbstractBasePtr abstract_v2 = abstract::FromValue(y_tensor, true);
|
||||||
args_spec_list.push_back(abstract_v1);
|
args_spec_list.push_back(abstract_v1);
|
||||||
args_spec_list.push_back(abstract_v2);
|
args_spec_list.push_back(abstract_v2);
|
||||||
AnalysisResult result = pipeline::AbstractAnalyze(res, before1, args_spec_list);
|
AnalysisResult result = pipeline::AbstractAnalyze(res, before1, args_spec_list);
|
||||||
|
|
|
@ -184,7 +184,7 @@ TEST_F(TestData, test_broaden) {
|
||||||
AbstractBasePtr s2 = s1->Broaden();
|
AbstractBasePtr s2 = s1->Broaden();
|
||||||
ASSERT_TRUE(*s1->GetTypeTrack() == *s2->GetTypeTrack());
|
ASSERT_TRUE(*s1->GetTypeTrack() == *s2->GetTypeTrack());
|
||||||
ASSERT_TRUE(*s1->GetValueTrack() == *MakeValue(int1));
|
ASSERT_TRUE(*s1->GetValueTrack() == *MakeValue(int1));
|
||||||
ASSERT_TRUE(s2->GetValueTrack()->isa<AnyValue>());
|
ASSERT_TRUE(s2->GetValueTrack()->isa<Int64Imm>());
|
||||||
|
|
||||||
AbstractFunctionPtr f1 = std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(),
|
AbstractFunctionPtr f1 = std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(),
|
||||||
AnalysisContext::DummyContext());
|
AnalysisContext::DummyContext());
|
||||||
|
@ -196,7 +196,7 @@ TEST_F(TestData, test_broaden) {
|
||||||
AbstractList* l2_cast = dynamic_cast<AbstractList*>(l2.get());
|
AbstractList* l2_cast = dynamic_cast<AbstractList*>(l2.get());
|
||||||
ASSERT_TRUE(l2_cast != nullptr);
|
ASSERT_TRUE(l2_cast != nullptr);
|
||||||
AbstractBasePtr csr = AbstractJoin(l2_cast->elements());
|
AbstractBasePtr csr = AbstractJoin(l2_cast->elements());
|
||||||
ASSERT_TRUE(csr->GetValueTrack()->isa<AnyValue>());
|
ASSERT_TRUE(csr->GetValueTrack()->isa<Int64Imm>());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace abstract
|
} // namespace abstract
|
||||||
|
|
|
@ -761,27 +761,6 @@ def test_while_scalar():
|
||||||
out = net(x, y)
|
out = net(x, y)
|
||||||
|
|
||||||
|
|
||||||
def test_while_tensor():
|
|
||||||
class Net(nn.Cell):
|
|
||||||
def __init__(self):
|
|
||||||
super(Net, self).__init__()
|
|
||||||
self.t = Tensor(np.ones([6, 8, 10], np.int32))
|
|
||||||
self.count = Tensor(np.array([10], np.int32))
|
|
||||||
|
|
||||||
def construct(self, x, y):
|
|
||||||
i = 0
|
|
||||||
t = self.t
|
|
||||||
while (i < self.count):
|
|
||||||
t = t + x + y
|
|
||||||
i = i + 1
|
|
||||||
return t
|
|
||||||
|
|
||||||
net = Net()
|
|
||||||
x = Tensor(np.ones([6, 8, 10], np.int32))
|
|
||||||
y = Tensor(np.ones([6, 8, 10], np.int32))
|
|
||||||
out = net(x, y)
|
|
||||||
|
|
||||||
|
|
||||||
def test_large_for_loop():
|
def test_large_for_loop():
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
|
@ -20,6 +20,7 @@ from mindspore import Tensor, Parameter
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore import dtype as mstype
|
from mindspore import dtype as mstype
|
||||||
from mindspore.nn import Cell
|
from mindspore.nn import Cell
|
||||||
|
from mindspore.ops import operations as P
|
||||||
from ....mindspore_test_framework.mindspore_test import mindspore_test
|
from ....mindspore_test_framework.mindspore_test import mindspore_test
|
||||||
from ....mindspore_test_framework.pipeline.forward.compile_forward \
|
from ....mindspore_test_framework.pipeline.forward.compile_forward \
|
||||||
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config, \
|
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config, \
|
||||||
|
@ -683,6 +684,27 @@ def test_tensor_assign_bool_index():
|
||||||
net4(Ta, Tensor(u_scalar, mstype.int32))
|
net4(Ta, Tensor(u_scalar, mstype.int32))
|
||||||
|
|
||||||
|
|
||||||
|
def test_trivial_call_function_twice_with_diff_key_value_para():
|
||||||
|
class Net(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.arange = Tensor(np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))
|
||||||
|
self.concat = P.Concat(axis=0)
|
||||||
|
|
||||||
|
def compute(self, x, is_decoder):
|
||||||
|
if is_decoder:
|
||||||
|
return self.arange[:x]
|
||||||
|
return self.arange[1:x + 1]
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
result1 = self.compute(7, is_decoder=True)
|
||||||
|
result2 = self.compute(6, is_decoder=False)
|
||||||
|
return self.concat((result1, result2))
|
||||||
|
|
||||||
|
net = Net()
|
||||||
|
net()
|
||||||
|
|
||||||
|
|
||||||
test_cases = [
|
test_cases = [
|
||||||
('TensorAssignWithTupleEllipsis2', {
|
('TensorAssignWithTupleEllipsis2', {
|
||||||
'block': TensorAssignWithTupleEllipsis2(),
|
'block': TensorAssignWithTupleEllipsis2(),
|
||||||
|
|
Loading…
Reference in New Issue