!11921 do not broaden when arg is not tensor

From: @zhangbuxue
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-02-03 17:07:47 +08:00 committed by Gitee
commit cbfba95ad0
8 changed files with 41 additions and 42 deletions

View File

@ -544,7 +544,7 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
} }
if (CanSpecializeNode(func)) { if (CanSpecializeNode(func)) {
// for primitive node , we build the primitive node with infered attributes in the first pass // for primitive node , we build the primitive node with inferred attributes in the first pass
// so we do not build replaced node again here in second pass // so we do not build replaced node again here in second pass
if (IsValueNode<Primitive>(func)) { if (IsValueNode<Primitive>(func)) {
new_inputs[0] = func; new_inputs[0] = func;
@ -666,14 +666,14 @@ AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin
AbstractFunctionPtr abs = dyn_cast<AbstractFunction>(ival); AbstractFunctionPtr abs = dyn_cast<AbstractFunction>(ival);
if (abs != nullptr) { if (abs != nullptr) {
// Cannot build a determinstic ValueNode if there are multiple possible AbstractFunction. // Cannot build a deterministic ValueNode if there are multiple possible AbstractFunction.
if (abs->isa<AbstractFuncUnion>()) { if (abs->isa<AbstractFuncUnion>()) {
return nullptr; return nullptr;
} }
ValuePtr value = nullptr; ValuePtr value = nullptr;
if (abs->isa<PrimitiveAbstractClosure>()) { if (abs->isa<PrimitiveAbstractClosure>()) {
auto real_fn = dyn_cast<PrimitiveAbstractClosure>(abs); auto real_fn = dyn_cast<PrimitiveAbstractClosure>(abs);
// for primitive, check if the attribute is the same with cnode infererd attribute ,if not, clone a new one // for primitive, check if the attribute is the same with cnode inferred attribute, if not, clone a new one
if (attrs != nullptr) { if (attrs != nullptr) {
value = BuildPrimtiveValueWithAttributes(real_fn->prim(), attrs); value = BuildPrimtiveValueWithAttributes(real_fn->prim(), attrs);
} else { } else {

View File

@ -88,7 +88,7 @@ std::string AbstractBase::ToString() const {
return buffer.str(); return buffer.str();
} }
AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const { return AbstractBase::Broaden(config); } AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const { return Clone(); }
AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) { AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
MS_EXCEPTION_IF_NULL(other); MS_EXCEPTION_IF_NULL(other);

View File

@ -102,7 +102,7 @@ AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePt
ValuePtr name_value = prim->GetAttr("tag"); ValuePtr name_value = prim->GetAttr("tag");
auto name = name_value->cast<StringImmPtr>(); auto name = name_value->cast<StringImmPtr>();
if (name == nullptr) { if (name == nullptr) {
MS_LOG(EXCEPTION) << "MakeRefKey attr tag sould be a String " << name_value->ToString() << "."; MS_LOG(EXCEPTION) << "MakeRefKey attr tag should be a String " << name_value->ToString() << ".";
} }
auto refkey = std::make_shared<RefKey>(name->value()); auto refkey = std::make_shared<RefKey>(name->value());
if (refkey == nullptr) { if (refkey == nullptr) {
@ -168,6 +168,9 @@ AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &p
MS_LOG(EXCEPTION) << primitive->name() << " input args size should be at lest 1, but got 0"; MS_LOG(EXCEPTION) << primitive->name() << " input args size should be at lest 1, but got 0";
} }
auto depends = args_spec_list[0]->Broaden(); auto depends = args_spec_list[0]->Broaden();
if (depends->isa<AbstractScalar>()) {
depends->set_value(kAnyValue);
}
return depends; return depends;
} }
@ -182,7 +185,7 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv
auto src_size = arg_src->cast<AbstractTuplePtr>()->size(); auto src_size = arg_src->cast<AbstractTuplePtr>()->size();
auto dst_size = arg_src->cast<AbstractTuplePtr>()->size(); auto dst_size = arg_src->cast<AbstractTuplePtr>()->size();
if (src_size > 1 && dst_size > 1) { if (src_size > 1 && dst_size > 1) {
MS_LOG(EXCEPTION) << "Control depend can not setup operator dependcy relationship from tuple from tuple"; MS_LOG(EXCEPTION) << "Control depend can not setup operator dependency relationship from tuple from tuple";
} }
} }
return std::make_shared<AbstractScalar>(kAnyValue, kBool); return std::make_shared<AbstractScalar>(kAnyValue, kBool);
@ -505,7 +508,7 @@ AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePt
auto axis = primitive->GetAttr("axis"); auto axis = primitive->GetAttr("axis");
auto value = GetValue<int64_t>(axis); auto value = GetValue<int64_t>(axis);
if (value < -(SizeToInt(x_shape.size()) + 1) || value > SizeToInt(x_shape.size())) { if (value < -(SizeToInt(x_shape.size()) + 1) || value > SizeToInt(x_shape.size())) {
MS_LOG(EXCEPTION) << " axis value shoud be in range [-intput_x.dim-1,input_x.dim], but axis value is" << value MS_LOG(EXCEPTION) << " axis value should be in range [-input_x.dim-1,input_x.dim], but axis value is" << value
<< " and input_x.dim is" << x_shape.size(); << " and input_x.dim is" << x_shape.size();
} }
if (value < 0) { if (value < 0) {

View File

@ -32,26 +32,18 @@ TEST_F(TestUtils, test_join) {
AbstractBasePtr abs_s1 = FromValue(static_cast<int64_t>(1), false); AbstractBasePtr abs_s1 = FromValue(static_cast<int64_t>(1), false);
AbstractBasePtr abs_s2 = FromValue(static_cast<int64_t>(2), false); AbstractBasePtr abs_s2 = FromValue(static_cast<int64_t>(2), false);
AbstractBasePtr abs_s_anything = FromValue(static_cast<int64_t>(2), true); AbstractBasePtr abs_s_anything = FromValue(static_cast<int64_t>(2), true);
abs_s_anything->set_value(kAnyValue);
AbstractBasePtr res_s1 = abs_s1->Join(abs_s2); AbstractBasePtr res_s1 = abs_s1->Join(abs_s2);
ASSERT_EQ(*res_s1, *abs_s_anything); ASSERT_EQ(*res_s1, *abs_s_anything);
// AbstractTuple join;
std::vector<int64_t> list1 = {1, 2, 3, 4, 5};
std::vector<int64_t> list2 = {5, 4, 3, 2, 1};
AbstractBasePtr abs_t1 = FromValue(list1, true);
AbstractBasePtr abs_t2 = FromValue(list2, true);
AbstractBasePtr res_t1 = abs_t1->Join(abs_t2);
ASSERT_EQ(res_t1, abs_t1);
abs_s1 = FromValue(static_cast<int64_t>(1), false); abs_s1 = FromValue(static_cast<int64_t>(1), false);
AbstractBasePtr t1 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything})); AbstractBasePtr t1 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything}));
AbstractBasePtr t2 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything})); AbstractBasePtr t2 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything}));
AbstractBasePtr t3 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s_anything, abs_s_anything})); AbstractBasePtr t3 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s_anything, abs_s_anything}));
res_t1 = t1->Join(t2); AbstractBasePtr res_t1 = t1->Join(t2);
ASSERT_EQ(res_t1, t1); ASSERT_EQ(res_t1, t1);
res_t1 = t1->Join(t3); res_t1 = t1->Join(t3);

View File

@ -111,8 +111,11 @@ TEST_F(TestOptLib, test_inline) {
// add infer and renormalize // add infer and renormalize
std::shared_ptr<mindspore::pipeline::Resource> res = std::make_shared<mindspore::pipeline::Resource>(); std::shared_ptr<mindspore::pipeline::Resource> res = std::make_shared<mindspore::pipeline::Resource>();
AbstractBasePtrList args_spec_list; AbstractBasePtrList args_spec_list;
AbstractBasePtr abstract_v1 = abstract::FromValue(static_cast<int64_t>(1), true); tensor::TensorPtr x_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{2, 3});
AbstractBasePtr abstract_v2 = abstract::FromValue(static_cast<int64_t>(2), true); tensor::TensorPtr y_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{2, 3});
AbstractBasePtr abstract_v1 = abstract::FromValue(x_tensor, true);
AbstractBasePtr abstract_v2 = abstract::FromValue(y_tensor, true);
args_spec_list.push_back(abstract_v1); args_spec_list.push_back(abstract_v1);
args_spec_list.push_back(abstract_v2); args_spec_list.push_back(abstract_v2);
AnalysisResult result = pipeline::AbstractAnalyze(res, before1, args_spec_list); AnalysisResult result = pipeline::AbstractAnalyze(res, before1, args_spec_list);

View File

@ -184,7 +184,7 @@ TEST_F(TestData, test_broaden) {
AbstractBasePtr s2 = s1->Broaden(); AbstractBasePtr s2 = s1->Broaden();
ASSERT_TRUE(*s1->GetTypeTrack() == *s2->GetTypeTrack()); ASSERT_TRUE(*s1->GetTypeTrack() == *s2->GetTypeTrack());
ASSERT_TRUE(*s1->GetValueTrack() == *MakeValue(int1)); ASSERT_TRUE(*s1->GetValueTrack() == *MakeValue(int1));
ASSERT_TRUE(s2->GetValueTrack()->isa<AnyValue>()); ASSERT_TRUE(s2->GetValueTrack()->isa<Int64Imm>());
AbstractFunctionPtr f1 = std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(), AbstractFunctionPtr f1 = std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(),
AnalysisContext::DummyContext()); AnalysisContext::DummyContext());
@ -196,7 +196,7 @@ TEST_F(TestData, test_broaden) {
AbstractList* l2_cast = dynamic_cast<AbstractList*>(l2.get()); AbstractList* l2_cast = dynamic_cast<AbstractList*>(l2.get());
ASSERT_TRUE(l2_cast != nullptr); ASSERT_TRUE(l2_cast != nullptr);
AbstractBasePtr csr = AbstractJoin(l2_cast->elements()); AbstractBasePtr csr = AbstractJoin(l2_cast->elements());
ASSERT_TRUE(csr->GetValueTrack()->isa<AnyValue>()); ASSERT_TRUE(csr->GetValueTrack()->isa<Int64Imm>());
} }
} // namespace abstract } // namespace abstract

View File

@ -761,27 +761,6 @@ def test_while_scalar():
out = net(x, y) out = net(x, y)
def test_while_tensor():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.t = Tensor(np.ones([6, 8, 10], np.int32))
self.count = Tensor(np.array([10], np.int32))
def construct(self, x, y):
i = 0
t = self.t
while (i < self.count):
t = t + x + y
i = i + 1
return t
net = Net()
x = Tensor(np.ones([6, 8, 10], np.int32))
y = Tensor(np.ones([6, 8, 10], np.int32))
out = net(x, y)
def test_large_for_loop(): def test_large_for_loop():
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self): def __init__(self):

View File

@ -20,6 +20,7 @@ from mindspore import Tensor, Parameter
from mindspore import context from mindspore import context
from mindspore import dtype as mstype from mindspore import dtype as mstype
from mindspore.nn import Cell from mindspore.nn import Cell
from mindspore.ops import operations as P
from ....mindspore_test_framework.mindspore_test import mindspore_test from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \ from ....mindspore_test_framework.pipeline.forward.compile_forward \
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config, \ import pipeline_for_compile_forward_ge_graph_for_case_by_case_config, \
@ -683,6 +684,27 @@ def test_tensor_assign_bool_index():
net4(Ta, Tensor(u_scalar, mstype.int32)) net4(Ta, Tensor(u_scalar, mstype.int32))
def test_trivial_call_function_twice_with_diff_key_value_para():
class Net(Cell):
def __init__(self):
super(Net, self).__init__()
self.arange = Tensor(np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))
self.concat = P.Concat(axis=0)
def compute(self, x, is_decoder):
if is_decoder:
return self.arange[:x]
return self.arange[1:x + 1]
def construct(self):
result1 = self.compute(7, is_decoder=True)
result2 = self.compute(6, is_decoder=False)
return self.concat((result1, result2))
net = Net()
net()
test_cases = [ test_cases = [
('TensorAssignWithTupleEllipsis2', { ('TensorAssignWithTupleEllipsis2', {
'block': TensorAssignWithTupleEllipsis2(), 'block': TensorAssignWithTupleEllipsis2(),