do not broaden scalar

buxue 2021-01-29 18:18:14 +08:00
parent 9557bef491
commit 6ccc4379b4
8 changed files with 41 additions and 42 deletions

View File

@@ -544,7 +544,7 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
}
if (CanSpecializeNode(func)) {
-// for primitive node , we build the primitive node with infered attributes in the first pass
+// for primitive node , we build the primitive node with inferred attributes in the first pass
// so we do not build replaced node again here in second pass
if (IsValueNode<Primitive>(func)) {
new_inputs[0] = func;
@@ -666,14 +666,14 @@ AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin
AbstractFunctionPtr abs = dyn_cast<AbstractFunction>(ival);
if (abs != nullptr) {
-// Cannot build a determinstic ValueNode if there are multiple possible AbstractFunction.
+// Cannot build a deterministic ValueNode if there are multiple possible AbstractFunction.
if (abs->isa<AbstractFuncUnion>()) {
return nullptr;
}
ValuePtr value = nullptr;
if (abs->isa<PrimitiveAbstractClosure>()) {
auto real_fn = dyn_cast<PrimitiveAbstractClosure>(abs);
-// for primitive, check if the attribute is the same with cnode infererd attribute ,if not, clone a new one
+// for primitive, check if the attribute is the same with cnode inferred attribute, if not, clone a new one
if (attrs != nullptr) {
value = BuildPrimtiveValueWithAttributes(real_fn->prim(), attrs);
} else {

View File

@@ -88,7 +88,7 @@ std::string AbstractBase::ToString() const {
return buffer.str();
}
-AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const { return AbstractBase::Broaden(config); }
+AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const { return Clone(); }
AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
MS_EXCEPTION_IF_NULL(other);
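
The one-line change above is the heart of this commit: AbstractScalar::Broaden used to delegate to AbstractBase::Broaden, which widens the value track to kAnyValue, while now it returns Clone(), so a scalar keeps its concrete value when broadened. A minimal self-contained sketch of the before/after semantics; MiniScalar and the std::optional stand-in for kAnyValue are illustrative assumptions, not the MindSpore types:

#include <cstdint>
#include <iostream>
#include <memory>
#include <optional>

// Stand-in for AbstractScalar; std::nullopt plays the role of kAnyValue.
struct MiniScalar {
  std::optional<int64_t> value;

  std::shared_ptr<MiniScalar> Clone() const { return std::make_shared<MiniScalar>(*this); }

  // Old behaviour: delegate to the base Broaden, which erases the value track.
  std::shared_ptr<MiniScalar> BroadenOld() const {
    auto res = Clone();
    res->value = std::nullopt;  // set_value(kAnyValue)
    return res;
  }

  // New behaviour: "do not broaden scalar" -- just clone, keeping the value.
  std::shared_ptr<MiniScalar> BroadenNew() const { return Clone(); }
};

int main() {
  MiniScalar s{42};
  std::cout << "old Broaden keeps value: " << s.BroadenOld()->value.has_value() << "\n";  // 0
  std::cout << "new Broaden keeps value: " << s.BroadenNew()->value.has_value() << "\n";  // 1
}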

View File

@@ -102,7 +102,7 @@ AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePt
ValuePtr name_value = prim->GetAttr("tag");
auto name = name_value->cast<StringImmPtr>();
if (name == nullptr) {
MS_LOG(EXCEPTION) << "MakeRefKey attr tag sould be a String " << name_value->ToString() << ".";
MS_LOG(EXCEPTION) << "MakeRefKey attr tag should be a String " << name_value->ToString() << ".";
}
auto refkey = std::make_shared<RefKey>(name->value());
if (refkey == nullptr) {
@@ -168,6 +168,9 @@ AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &p
MS_LOG(EXCEPTION) << primitive->name() << " input args size should be at lest 1, but got 0";
}
auto depends = args_spec_list[0]->Broaden();
+if (depends->isa<AbstractScalar>()) {
+  depends->set_value(kAnyValue);
+}
return depends;
}
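
With scalars no longer widened by Broaden(), InferImplDepend has to erase the value track itself, otherwise the output of a Depend node on a scalar would still be treated as a compile-time constant. A hedged sketch of that pattern, reusing the simplified stand-in type from above rather than MindSpore's real API:

#include <cstdint>
#include <memory>
#include <optional>

// Same simplified stand-in as before: std::nullopt models kAnyValue.
struct MiniScalar {
  std::optional<int64_t> value;
  // New Broaden semantics: a plain clone, value preserved.
  std::shared_ptr<MiniScalar> Broaden() const { return std::make_shared<MiniScalar>(*this); }
};

// Mirrors the hunk above: broaden first, then explicitly widen scalars so the
// Depend output is never constant-folded. (The real code checks
// isa<AbstractScalar>; here every value is a scalar by construction.)
std::shared_ptr<MiniScalar> InferDepend(const MiniScalar &arg) {
  auto depends = arg.Broaden();
  depends->value = std::nullopt;  // depends->set_value(kAnyValue)
  return depends;
}

int main() { return InferDepend(MiniScalar{7})->value.has_value() ? 1 : 0; }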
@@ -182,7 +185,7 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv
auto src_size = arg_src->cast<AbstractTuplePtr>()->size();
auto dst_size = arg_src->cast<AbstractTuplePtr>()->size();
if (src_size > 1 && dst_size > 1) {
MS_LOG(EXCEPTION) << "Control depend can not setup operator dependcy relationship from tuple from tuple";
MS_LOG(EXCEPTION) << "Control depend can not setup operator dependency relationship from tuple from tuple";
}
}
return std::make_shared<AbstractScalar>(kAnyValue, kBool);
@@ -505,7 +508,7 @@ AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePt
auto axis = primitive->GetAttr("axis");
auto value = GetValue<int64_t>(axis);
if (value < -(SizeToInt(x_shape.size()) + 1) || value > SizeToInt(x_shape.size())) {
MS_LOG(EXCEPTION) << " axis value shoud be in range [-intput_x.dim-1,input_x.dim], but axis value is" << value
MS_LOG(EXCEPTION) << " axis value should be in range [-input_x.dim-1,input_x.dim], but axis value is" << value
<< " and input_x.dim is" << x_shape.size();
}
if (value < 0) {

View File

@@ -32,26 +32,18 @@ TEST_F(TestUtils, test_join) {
AbstractBasePtr abs_s1 = FromValue(static_cast<int64_t>(1), false);
AbstractBasePtr abs_s2 = FromValue(static_cast<int64_t>(2), false);
AbstractBasePtr abs_s_anything = FromValue(static_cast<int64_t>(2), true);
+abs_s_anything->set_value(kAnyValue);
AbstractBasePtr res_s1 = abs_s1->Join(abs_s2);
ASSERT_EQ(*res_s1, *abs_s_anything);
-
-// AbstractTuple join;
-std::vector<int64_t> list1 = {1, 2, 3, 4, 5};
-std::vector<int64_t> list2 = {5, 4, 3, 2, 1};
-AbstractBasePtr abs_t1 = FromValue(list1, true);
-AbstractBasePtr abs_t2 = FromValue(list2, true);
-AbstractBasePtr res_t1 = abs_t1->Join(abs_t2);
-ASSERT_EQ(res_t1, abs_t1);
-abs_s1 = FromValue(static_cast<int64_t>(1), false);
AbstractBasePtr t1 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything}));
AbstractBasePtr t2 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything}));
AbstractBasePtr t3 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s_anything, abs_s_anything}));
-res_t1 = t1->Join(t2);
+AbstractBasePtr res_t1 = t1->Join(t2);
ASSERT_EQ(res_t1, t1);
res_t1 = t1->Join(t3);
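
The test adjustment follows from the same change: FromValue(2, true) now keeps the concrete value 2, so the expected result of joining two different scalars has to be widened by hand with set_value(kAnyValue). A small sketch of the assumed scalar Join rule, again with simplified stand-in types rather than the real AbstractBase::Join:

#include <cassert>
#include <cstdint>
#include <optional>

// std::nullopt models kAnyValue.
using MiniScalar = std::optional<int64_t>;

// Assumed join rule for same-typed scalars: identical values join to
// themselves, anything else widens to AnyValue.
MiniScalar Join(const MiniScalar &a, const MiniScalar &b) { return a == b ? a : std::nullopt; }

int main() {
  MiniScalar abs_s1{1}, abs_s2{2};
  MiniScalar abs_s_anything = std::nullopt;        // after set_value(kAnyValue)
  assert(Join(abs_s1, abs_s2) == abs_s_anything);  // 1 join 2 -> AnyValue
  assert(Join(abs_s1, abs_s1) == MiniScalar{1});   // 1 join 1 -> 1
  return 0;
}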

View File

@@ -111,8 +111,11 @@ TEST_F(TestOptLib, test_inline) {
// add infer and renormalize
std::shared_ptr<mindspore::pipeline::Resource> res = std::make_shared<mindspore::pipeline::Resource>();
AbstractBasePtrList args_spec_list;
-AbstractBasePtr abstract_v1 = abstract::FromValue(static_cast<int64_t>(1), true);
-AbstractBasePtr abstract_v2 = abstract::FromValue(static_cast<int64_t>(2), true);
+tensor::TensorPtr x_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{2, 3});
+tensor::TensorPtr y_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{2, 3});
+AbstractBasePtr abstract_v1 = abstract::FromValue(x_tensor, true);
+AbstractBasePtr abstract_v2 = abstract::FromValue(y_tensor, true);
args_spec_list.push_back(abstract_v1);
args_spec_list.push_back(abstract_v2);
AnalysisResult result = pipeline::AbstractAnalyze(res, before1, args_spec_list);
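
A hedged reading of why this test switches to tensor arguments: with broaden=true, a scalar abstract now keeps its concrete value, so the graph under test would specialize on the constants 1 and 2, while a tensor abstract still has its value track widened. A sketch under that assumption, with simplified types that are not abstract::FromValue itself:

#include <cstdint>
#include <iostream>
#include <optional>

// Simplified abstract argument: only the value track is modelled;
// std::nullopt stands in for kAnyValue.
struct AbstractArg {
  bool is_tensor;
  std::optional<int64_t> value;
};

// Assumed post-commit behaviour of FromValue(v, broaden=true):
// tensors are still widened, scalars keep their concrete value.
AbstractArg FromValueBroadened(bool is_tensor, int64_t v) {
  return is_tensor ? AbstractArg{true, std::nullopt} : AbstractArg{false, v};
}

int main() {
  std::cout << "scalar arg keeps value: " << FromValueBroadened(false, 1).value.has_value() << "\n";  // 1
  std::cout << "tensor arg keeps value: " << FromValueBroadened(true, 1).value.has_value() << "\n";   // 0
}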

View File

@@ -184,7 +184,7 @@ TEST_F(TestData, test_broaden) {
AbstractBasePtr s2 = s1->Broaden();
ASSERT_TRUE(*s1->GetTypeTrack() == *s2->GetTypeTrack());
ASSERT_TRUE(*s1->GetValueTrack() == *MakeValue(int1));
-ASSERT_TRUE(s2->GetValueTrack()->isa<AnyValue>());
+ASSERT_TRUE(s2->GetValueTrack()->isa<Int64Imm>());
AbstractFunctionPtr f1 = std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(),
AnalysisContext::DummyContext());
@@ -196,7 +196,7 @@
AbstractList* l2_cast = dynamic_cast<AbstractList*>(l2.get());
ASSERT_TRUE(l2_cast != nullptr);
AbstractBasePtr csr = AbstractJoin(l2_cast->elements());
-ASSERT_TRUE(csr->GetValueTrack()->isa<AnyValue>());
+ASSERT_TRUE(csr->GetValueTrack()->isa<Int64Imm>());
}
} // namespace abstract

View File

@@ -761,27 +761,6 @@ def test_while_scalar():
    out = net(x, y)
-
-
-def test_while_tensor():
-    class Net(nn.Cell):
-        def __init__(self):
-            super(Net, self).__init__()
-            self.t = Tensor(np.ones([6, 8, 10], np.int32))
-            self.count = Tensor(np.array([10], np.int32))
-
-        def construct(self, x, y):
-            i = 0
-            t = self.t
-            while (i < self.count):
-                t = t + x + y
-                i = i + 1
-            return t
-
-    net = Net()
-    x = Tensor(np.ones([6, 8, 10], np.int32))
-    y = Tensor(np.ones([6, 8, 10], np.int32))
-    out = net(x, y)
def test_large_for_loop():
class Net(nn.Cell):
def __init__(self):

View File

@@ -20,6 +20,7 @@ from mindspore import Tensor, Parameter
from mindspore import context
from mindspore import dtype as mstype
from mindspore.nn import Cell
+from mindspore.ops import operations as P
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config, \
@@ -683,6 +684,27 @@ def test_tensor_assign_bool_index():
    net4(Ta, Tensor(u_scalar, mstype.int32))
+
+
+def test_trivial_call_function_twice_with_diff_key_value_para():
+    class Net(Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.arange = Tensor(np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))
+            self.concat = P.Concat(axis=0)
+
+        def compute(self, x, is_decoder):
+            if is_decoder:
+                return self.arange[:x]
+            return self.arange[1:x + 1]
+
+        def construct(self):
+            result1 = self.compute(7, is_decoder=True)
+            result2 = self.compute(6, is_decoder=False)
+            return self.concat((result1, result2))
+
+    net = Net()
+    net()
test_cases = [
('TensorAssignWithTupleEllipsis2', {
'block': TensorAssignWithTupleEllipsis2(),