support scalar input for cell

This commit is contained in:
yujianfeng 2021-02-07 17:04:26 +08:00
parent 5224241ca7
commit 41189781f3
25 changed files with 192 additions and 183 deletions

View File

@ -453,7 +453,7 @@ class Validator:
return padding
@staticmethod
def check_subclass(arg_name, type_, template_types, prim_name):
def check_subclass(arg_name, type_, template_types, prim_name, addition_error_info=None):
"""Checks whether some type is subclass of another type"""
if not isinstance(template_types, Iterable):
template_types = (template_types,)
@ -467,9 +467,12 @@ class Validator:
hit = True
break
if not hit:
if addition_error_info is None:
addition_error_info = ''
type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
raise TypeError(f'For \'{prim_name}\', the type of `{arg_name}` should be subclass'
f' of {", ".join((str(x) for x in template_types))}, but got {type_str}.')
f' of {", ".join((str(x) for x in template_types))}, but got {type_str}.'
f' {addition_error_info}')
@staticmethod
def check_const_input(arg_name, arg_value, prim_name):

View File

@ -401,7 +401,9 @@ FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &
}
if (tail_type_ == kGradFirst) {
if (sequeue->size() > 1 && (*sequeue)[1] != nullptr && (*sequeue)[1]->isa<abstract::AbstractUndetermined>()) {
if (sequeue->size() > 1 && (*sequeue)[1] != nullptr &&
((*sequeue)[1]->isa<abstract::AbstractUndetermined>() ||
((*sequeue)[1]->BuildType() != nullptr && (*sequeue)[1]->BuildType()->isa<Number>()))) {
ret->set_output(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(1))}));
} else {
ret->set_output(NewValueNode(std::make_shared<ValueTuple>(std::vector<ValuePtr>{})));
@ -413,7 +415,8 @@ FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &
for (size_t i = 1; i < sequeue->size(); ++i) {
if (tail_type_ == kGradAll) {
MS_EXCEPTION_IF_NULL((*sequeue)[i]);
if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>()) {
if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>() ||
((*sequeue)[i]->BuildType() != nullptr && (*sequeue)[i]->BuildType()->isa<Number>())) {
elems.push_back(ret->NewCNodeInOrder({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))}));
}
} else {

View File

@ -224,7 +224,8 @@ void DeleteLoadUserUpdateState(const FuncGraphManagerPtr &manager, const AnfNode
// b = Load(para1, u2)
// u3 = UpdateState(u2, x)
void DeleteLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const CNodePtr &make_tuple, const AnfNodePtr &load) {
AnfNodePtr other_input = nullptr;
// Initialize the other_input with load in case of all the inputs of the make_tuple is the same load.
AnfNodePtr other_input = load;
for (size_t i = 1; i < make_tuple->size(); i++) {
if (make_tuple->input(i) != load) {
other_input = make_tuple->input(i);

View File

@ -489,7 +489,8 @@ void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) {
continue;
}
AbstractBasePtr par_abs = param_node->abstract();
if (par_abs->isa<abstract::AbstractUndetermined>()) {
if (par_abs->isa<abstract::AbstractUndetermined>() ||
(par_abs->BuildType() != nullptr && par_abs->BuildType()->isa<Number>())) {
new_paras.push_back(param_node);
}
}

View File

@ -98,7 +98,7 @@ std::string GetBaseNameForIR(int64_t stage_idx, const std::string &action_name)
AbstractBasePtr ArgsToAbstract(const ValuePtr &value) {
MS_EXCEPTION_IF_NULL(value);
bool broaden = value->isa<MetaTensor>();
bool broaden = value->isa<MetaTensor>() || value->isa<Scalar>();
return abstract::FromValue(value, broaden);
}
@ -142,6 +142,21 @@ std::string GetCompileExceptionInfo() {
return oss.str();
}
void SetGpuLoopSink(const ResourcePtr &resource_) {
auto func_graph = resource_->func_graph();
if (func_graph != nullptr && func_graph->manager() != nullptr) {
auto manager = func_graph->manager();
size_t graph_nums = manager->func_graphs().size();
int64_t sinksize = ConfigManager::GetInstance().iter_num();
if (graph_nums == 1) {
resource_->set_gpu_loopsink(true, sinksize);
} else {
resource_->set_gpu_loopsink(false, sinksize);
}
MS_LOG(INFO) << "Change gpu_loopsink_flag_ to " << resource_->gpu_loopsink_flag() << ", set loopsink size to "
<< sinksize;
}
}
} // namespace
py::tuple GenerateKey(const std::string &name, const std::unordered_map<std::string, py::object> &defaults) {
@ -704,19 +719,7 @@ void Pipeline::Run() {
MS_LOG(DEBUG) << "Action " << action.first << " end.";
};
if (action.first == "task_emit") {
auto func_graph = resource_->func_graph();
if (func_graph != nullptr && func_graph->manager() != nullptr) {
auto manager = func_graph->manager();
size_t graph_nums = manager->func_graphs().size();
int64_t sinksize = ConfigManager::GetInstance().iter_num();
if (graph_nums == 1) {
resource_->set_gpu_loopsink(true, sinksize);
} else {
resource_->set_gpu_loopsink(false, sinksize);
}
MS_LOG(INFO) << "Change gpu_loopsink_flag_ to " << resource_->gpu_loopsink_flag() << ", set loopsink size to "
<< sinksize;
}
SetGpuLoopSink(resource_);
}
if (!result) {
MS_LOG(EXCEPTION) << "Pipeline running to end, failed in step:" << action.first;

View File

@ -210,7 +210,7 @@ class _MindSporeFunction:
return None
new_inputs = []
for i in args_list:
if isinstance(i, Tensor):
if isinstance(i, (Tensor, int, float)):
new_inputs.append(i)
return self._executor(tuple(new_inputs), phase)

View File

@ -88,7 +88,7 @@ std::string AbstractBase::ToString() const {
return buffer.str();
}
AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const { return Clone(); }
AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const { return AbstractBase::Broaden(config); }
AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
MS_EXCEPTION_IF_NULL(other);

View File

@ -171,10 +171,6 @@ AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &p
return args_spec_list[0];
}
auto depends = args_spec_list[0]->Broaden();
// For scalar, need to set value to kAnyValue, because broaden scalar will not change the value.
if (depends->isa<AbstractScalar>()) {
depends->set_value(kAnyValue);
}
return depends;
}

View File

@ -609,7 +609,7 @@ class Cell(Cell_):
new_inputs = []
for i in inputs:
if isinstance(i, Tensor):
if isinstance(i, (Tensor, int, float)):
new_inputs.append(i)
if self._auto_parallel_mode:

View File

@ -199,10 +199,10 @@ class ForwardValueAndGrad(Cell):
If the sensor_param is True, a sensitivity (gradient with respect to output) needs to be transferred through
the location parameter or key-value pair parameter. If the value is transferred through the key-value pair
parameter, the key must be sens.
sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
Inputs:
- **(\*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
- sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
Outputs:
- **forward value** (a scalar Tensor with shape :math:`()`) - The result of network forward running.
@ -242,7 +242,7 @@ class ForwardValueAndGrad(Cell):
>>> loss, grads = forward_value_and_grad(inputs, labels, 1.0)
"""
def __init__(self, network, weights=None, get_all=False, get_by_list=False, sens_param=False):
def __init__(self, network, weights=None, get_all=False, get_by_list=False, sens_param=False, sens=1.0):
super(ForwardValueAndGrad, self).__init__(auto_prefix=False)
if not isinstance(network, (Cell, FunctionType, MethodType)):
raise TypeError(f"The type of training network should be cell, function type or method type, "
@ -259,19 +259,16 @@ class ForwardValueAndGrad(Cell):
self.get_all = get_all
self.get_by_list = get_by_list
self.sens_param = sens_param
self.sens = sens
self.grad = C.GradOperation(get_all=self.get_all, get_by_list=self.get_by_list, sens_param=self.sens_param)
def construct(self, *inputs):
weights = self.weights
if self.sens_param:
sens = inputs[-1]
inputs = inputs[:-1]
else:
sens = None
loss = self.network(*inputs)
if self.sens_param:
if not isinstance(sens, Tensor):
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), sens)
sens = self.sens
if not isinstance(self.sens, Tensor):
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(*inputs, sens)
else:
grads = self.grad(self.network, weights)(*inputs)

View File

@ -223,7 +223,8 @@ class DType(PrimitiveWithInfer):
"""Initialize DType"""
def __infer__(self, x):
validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name)
addition_error_info = 'Perhaps you are using a mixture of tensors and scalars to operate.'
validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name, addition_error_info)
out = {'shape': (),
'dtype': mstype.type_type,
'value': x['dtype'].element_type()}

View File

@ -414,13 +414,14 @@ def test_trainTensor_with_new_interface(num_classes=10, epoch=8, batch_size=1):
weights = ParameterTuple(filter(lambda x: x.requires_grad, net.get_parameters()))
optimizer = Momentum(weights, 0.1, 0.9)
train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True)
train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True,
sens=1.0)
losses = []
for i in range(0, epoch):
data = Tensor(np.ones([batch_size, 3, 224, 224]
).astype(np.float32) * 0.01)
label = Tensor(np.ones([batch_size]).astype(np.int32))
loss, grads = train_network(data, label, 1.0)
loss, grads = train_network(data, label)
grads = F.identity(grads)
optimizer(grads)
losses.append(loss)
@ -439,13 +440,14 @@ def test_big_batchSize_with_new_interface(num_classes=10, epoch=8, batch_size=33
weights = ParameterTuple(filter(lambda x: x.requires_grad, net.get_parameters()))
optimizer = Momentum(weights, 0.1, 0.9)
train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True)
train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True,
sens=1.0)
losses = []
for i in range(0, epoch):
data = Tensor(np.ones([batch_size, 3, 224, 224]
).astype(np.float32) * 0.01)
label = Tensor(np.ones([batch_size]).astype(np.int32))
loss, grads = train_network(data, label, 1.0)
loss, grads = train_network(data, label)
grads = F.identity(grads)
optimizer(grads)
losses.append(loss)

View File

@ -95,15 +95,23 @@ def test_ReduceAll():
assert output[3].shape == expect3.shape
x_1 = np.array([[True, True], [True, False], [False, False]])
axis_1 = 0
x_2 = np.array([[True, True], [True, True], [True, False], [False, False]])
axis_2 = 0
class ReduceAllDynamic(nn.Cell):
def __init__(self):
def __init__(self, x, axis):
super(ReduceAllDynamic, self).__init__()
self.reduceall = P.ReduceAll(False)
self.test_dynamic = inner.GpuConvertToDynamicShape()
self.x = x
self.axis = axis
def construct(self, x, axis):
x = self.test_dynamic(x)
return self.reduceall(x, axis)
def construct(self):
dynamic_x = self.test_dynamic(self.x)
return self.reduceall(dynamic_x, self.axis)
@pytest.mark.level0
@ -111,18 +119,14 @@ class ReduceAllDynamic(nn.Cell):
@pytest.mark.env_onecard
def test_reduce_all_dynamic():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
net = ReduceAllDynamic()
net1 = ReduceAllDynamic(Tensor(x_1), axis_1)
net2 = ReduceAllDynamic(Tensor(x_2), axis_2)
x_1 = np.array([[True, True], [True, False], [False, False]])
axis_1 = 0
expect_1 = np.all(x_1, axis=axis_1, keepdims=False)
x_2 = np.array([[True, True], [True, True], [True, False], [False, False]])
axis_2 = 0
expect_2 = np.all(x_2, axis=axis_2, keepdims=False)
output_1 = net(Tensor(x_1), axis_1)
output_2 = net(Tensor(x_2), axis_2)
output1 = net1()
output2 = net2()
np.testing.assert_almost_equal(output_1.asnumpy(), expect_1)
np.testing.assert_almost_equal(output_2.asnumpy(), expect_2)
np.testing.assert_almost_equal(output1.asnumpy(), expect_1)
np.testing.assert_almost_equal(output2.asnumpy(), expect_2)

View File

@ -95,15 +95,23 @@ def test_ReduceAny():
assert output[3].shape == expect3.shape
x_1 = np.array([[True, True], [True, False], [False, False]])
axis_1 = 0
x_2 = np.array([[True, True], [True, True], [True, False], [False, False]])
axis_2 = 0
class ReduceAnyDynamic(nn.Cell):
def __init__(self):
def __init__(self, x, axis):
super(ReduceAnyDynamic, self).__init__()
self.reduceany = P.ReduceAny(False)
self.test_dynamic = inner.GpuConvertToDynamicShape()
self.x = x
self.axis = axis
def construct(self, x, axis):
x = self.test_dynamic(x)
return self.reduceany(x, axis)
def construct(self):
dynamic_x = self.test_dynamic(self.x)
return self.reduceany(dynamic_x, self.axis)
@pytest.mark.level0
@ -111,18 +119,14 @@ class ReduceAnyDynamic(nn.Cell):
@pytest.mark.env_onecard
def test_reduce_any_dynamic():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
net = ReduceAnyDynamic()
net1 = ReduceAnyDynamic(Tensor(x_1), axis_1)
net2 = ReduceAnyDynamic(Tensor(x_2), axis_2)
x_1 = np.array([[True, True], [True, False], [False, False]])
axis_1 = 0
expect_1 = np.any(x_1, axis=axis_1, keepdims=False)
x_2 = np.array([[True, True], [True, True], [True, False], [False, False]])
axis_2 = 0
expect_2 = np.any(x_2, axis=axis_2, keepdims=False)
output_1 = net(Tensor(x_1), axis_1)
output_2 = net(Tensor(x_2), axis_2)
output1 = net1()
output2 = net2()
np.testing.assert_almost_equal(output_1.asnumpy(), expect_1)
np.testing.assert_almost_equal(output_2.asnumpy(), expect_2)
np.testing.assert_almost_equal(output1.asnumpy(), expect_1)
np.testing.assert_almost_equal(output2.asnumpy(), expect_2)

View File

@ -179,36 +179,41 @@ def test_ReduceMax():
assert np.all(diff8 < error8)
x_1 = x8
axis_1 = 0
x_2 = x1
axis_2 = 0
class ReduceMaxDynamic(nn.Cell):
def __init__(self):
def __init__(self, x, axis):
super(ReduceMaxDynamic, self).__init__()
self.reducemax = P.ReduceMax(False)
self.test_dynamic = inner.GpuConvertToDynamicShape()
self.x = x
self.axis = axis
def construct(self, x, axis):
x = self.test_dynamic(x)
return self.reducemax(x, axis)
def construct(self):
dynamic_x = self.test_dynamic(self.x)
return self.reducemax(dynamic_x, self.axis)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_reduce_max_dynamic():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
net = ReduceMaxDynamic()
net1 = ReduceMaxDynamic(Tensor(x_1), axis_1)
net2 = ReduceMaxDynamic(Tensor(x_2), axis_2)
x_1 = x8
axis_1 = 0
expect_1 = np.max(x_1, axis=0, keepdims=False)
x_2 = x1
axis_2 = 0
expect_2 = np.max(x_2, axis=0, keepdims=False)
output_1 = net(Tensor(x_1), axis_1)
output_2 = net(Tensor(x_2), axis_2)
output1 = net1()
output2 = net2()
np.testing.assert_almost_equal(output1.asnumpy(), expect_1)
np.testing.assert_almost_equal(output2.asnumpy(), expect_2)
np.testing.assert_almost_equal(output_1.asnumpy(), expect_1)
np.testing.assert_almost_equal(output_2.asnumpy(), expect_2)
class ReduceMaxTypeNet(nn.Cell):
def __init__(self, nptype):

View File

@ -268,14 +268,16 @@ def test_ReduceMean():
assert output[14].shape == expect14.shape
class ReduceMeanDynamic(nn.Cell):
def __init__(self, keepdims=False):
def __init__(self, x, axis, keepdims=False):
super(ReduceMeanDynamic, self).__init__()
self.test_dynamic = inner.GpuConvertToDynamicShape()
self.reducemean = P.ReduceMean(keep_dims=keepdims)
self.x = x
self.axis = axis
def construct(self, input_x, axis):
input_x = self.test_dynamic(input_x)
output = self.reducemean(input_x, axis)
def construct(self):
dynamic_x = self.test_dynamic(self.x)
output = self.reducemean(dynamic_x, self.axis)
return output
@pytest.mark.level0
@ -283,32 +285,30 @@ class ReduceMeanDynamic(nn.Cell):
@pytest.mark.env_onecard
def test_dynamic_reduce_mean_keepdims_true():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
net = ReduceMeanDynamic(keepdims=True)
x_tensor_1 = Tensor(x14)
output_1 = net(x_tensor_1, axis14)
x_tensor_2 = Tensor(x0)
output_2 = net(x_tensor_2, axis0)
net1 = ReduceMeanDynamic(Tensor(x14), axis14, keepdims=True)
net2 = ReduceMeanDynamic(Tensor(x0), axis0, keepdims=True)
output1 = net1()
output2 = net2()
expect_1 = np.mean(x14, axis=np_axis14, keepdims=True)
diff_1 = abs(output_1.asnumpy() - expect_1)
diff_1 = abs(output1.asnumpy() - expect_1)
error_1 = np.ones(shape=expect_1.shape) * 1.0e-5
assert np.all(diff_1 < error_1)
assert output_1.shape == expect_1.shape
assert output1.shape == expect_1.shape
expect_2 = np.mean(x0, axis=axis0, keepdims=True)
diff_2 = abs(output_2.asnumpy() - expect_2)
diff_2 = abs(output2.asnumpy() - expect_2)
error_2 = np.ones(shape=expect_2.shape) * 1.0e-5
assert np.all(diff_2 < error_2)
assert output_2.shape == expect_2.shape
assert output2.shape == expect_2.shape
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_dynamic_reduce_mean_keepdims_false():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
net = ReduceMeanDynamic(keepdims=False)
x_tensor = Tensor(x12)
output = net(x_tensor, axis12)
net = ReduceMeanDynamic(Tensor(x12), axis12, keepdims=False)
output = net()
expect = np.mean(x12, axis=axis12, keepdims=False)
diff = abs(output.asnumpy() - expect)

View File

@ -179,33 +179,37 @@ def test_ReduceMin():
assert np.all(diff8 < error8)
x_1 = x8
axis_1 = 0
x_2 = x1
axis_2 = 0
class ReduceMinDynamic(nn.Cell):
def __init__(self):
def __init__(self, x, axis):
super(ReduceMinDynamic, self).__init__()
self.reducemin = P.ReduceMin(False)
self.test_dynamic = inner.GpuConvertToDynamicShape()
self.x = x
self.axis = axis
def construct(self, x, axis):
x = self.test_dynamic(x)
return self.reducemin(x, axis)
def construct(self):
dynamic_x = self.test_dynamic(self.x)
return self.reducemin(dynamic_x, self.axis)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_reduce_min_dynamic():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
net = ReduceMinDynamic()
net1 = ReduceMinDynamic(Tensor(x_1), axis_1)
net2 = ReduceMinDynamic(Tensor(x_2), axis_2)
x_1 = x8
axis_1 = 0
expect_1 = np.min(x_1, axis=0, keepdims=False)
x_2 = x1
axis_2 = 0
expect_2 = np.min(x_2, axis=0, keepdims=False)
output_1 = net(Tensor(x_1), axis_1)
output_2 = net(Tensor(x_2), axis_2)
output1 = net1()
output2 = net2()
np.testing.assert_almost_equal(output_1.asnumpy(), expect_1)
np.testing.assert_almost_equal(output_2.asnumpy(), expect_2)
np.testing.assert_almost_equal(output1.asnumpy(), expect_1)
np.testing.assert_almost_equal(output2.asnumpy(), expect_2)

View File

@ -270,15 +270,23 @@ def test_ReduceSum():
assert output[14].shape == expect14.shape
x_1 = x8
axis_1 = 0
x_2 = x1
axis_2 = 0
class ReduceSumDynamic(nn.Cell):
def __init__(self):
def __init__(self, x, axis):
super(ReduceSumDynamic, self).__init__()
self.reducesum = P.ReduceSum(True)
self.test_dynamic = inner.GpuConvertToDynamicShape()
self.x = x
self.axis = axis
def construct(self, x, axis):
x = self.test_dynamic(x)
return self.reducesum(x, axis)
def construct(self):
dynamic_x = self.test_dynamic(self.x)
return self.reducesum(dynamic_x, self.axis)
@pytest.mark.level0
@ -286,21 +294,18 @@ class ReduceSumDynamic(nn.Cell):
@pytest.mark.env_onecard
def test_reduce_sum_dynamic():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
net = ReduceSumDynamic()
net1 = ReduceSumDynamic(Tensor(x_1), axis_1)
net2 = ReduceSumDynamic(Tensor(x_2), axis_2)
x_1 = x8
axis_1 = 0
expect_1 = np.sum(x_1, axis=axis_1, keepdims=True)
x_2 = x1
axis_2 = 0
expect_2 = np.sum(x_2, axis=axis_2, keepdims=True)
output_1 = net(Tensor(x_1), axis_1)
output_2 = net(Tensor(x_2), axis_2)
output1 = net1()
output2 = net2()
np.testing.assert_almost_equal(output1.asnumpy(), expect_1)
np.testing.assert_almost_equal(output2.asnumpy(), expect_2)
np.testing.assert_almost_equal(output_1.asnumpy(), expect_1)
np.testing.assert_almost_equal(output_2.asnumpy(), expect_2)
class ReduceSumTypeNet(nn.Cell):
def __init__(self, nptype):

View File

@ -32,18 +32,26 @@ TEST_F(TestUtils, test_join) {
AbstractBasePtr abs_s1 = FromValue(static_cast<int64_t>(1), false);
AbstractBasePtr abs_s2 = FromValue(static_cast<int64_t>(2), false);
AbstractBasePtr abs_s_anything = FromValue(static_cast<int64_t>(2), true);
abs_s_anything->set_value(kAnyValue);
AbstractBasePtr res_s1 = abs_s1->Join(abs_s2);
ASSERT_EQ(*res_s1, *abs_s_anything);
// AbstractTuple join;
std::vector<int64_t> list1 = {1, 2, 3, 4, 5};
std::vector<int64_t> list2 = {5, 4, 3, 2, 1};
AbstractBasePtr abs_t1 = FromValue(list1, true);
AbstractBasePtr abs_t2 = FromValue(list2, true);
AbstractBasePtr res_t1 = abs_t1->Join(abs_t2);
ASSERT_EQ(res_t1, abs_t1);
abs_s1 = FromValue(static_cast<int64_t>(1), false);
AbstractBasePtr t1 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything}));
AbstractBasePtr t2 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything}));
AbstractBasePtr t3 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s_anything, abs_s_anything}));
AbstractBasePtr res_t1 = t1->Join(t2);
res_t1 = t1->Join(t2);
ASSERT_EQ(res_t1, t1);
res_t1 = t1->Join(t3);

View File

@ -111,11 +111,8 @@ TEST_F(TestOptLib, test_inline) {
// add infer and renormalize
std::shared_ptr<mindspore::pipeline::Resource> res = std::make_shared<mindspore::pipeline::Resource>();
AbstractBasePtrList args_spec_list;
tensor::TensorPtr x_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{2, 3});
tensor::TensorPtr y_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{2, 3});
AbstractBasePtr abstract_v1 = abstract::FromValue(x_tensor, true);
AbstractBasePtr abstract_v2 = abstract::FromValue(y_tensor, true);
AbstractBasePtr abstract_v1 = abstract::FromValue(static_cast<int64_t>(1), true);
AbstractBasePtr abstract_v2 = abstract::FromValue(static_cast<int64_t>(2), true);
args_spec_list.push_back(abstract_v1);
args_spec_list.push_back(abstract_v2);
AnalysisResult result = pipeline::AbstractAnalyze(res, before1, args_spec_list);

View File

@ -184,7 +184,7 @@ TEST_F(TestData, test_broaden) {
AbstractBasePtr s2 = s1->Broaden();
ASSERT_TRUE(*s1->GetTypeTrack() == *s2->GetTypeTrack());
ASSERT_TRUE(*s1->GetValueTrack() == *MakeValue(int1));
ASSERT_TRUE(s2->GetValueTrack()->isa<Int64Imm>());
ASSERT_TRUE(s2->GetValueTrack()->isa<AnyValue>());
AbstractFunctionPtr f1 = std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(),
AnalysisContext::DummyContext());
@ -196,7 +196,7 @@ TEST_F(TestData, test_broaden) {
AbstractList* l2_cast = dynamic_cast<AbstractList*>(l2.get());
ASSERT_TRUE(l2_cast != nullptr);
AbstractBasePtr csr = AbstractJoin(l2_cast->elements());
ASSERT_TRUE(csr->GetValueTrack()->isa<Int64Imm>());
ASSERT_TRUE(csr->GetValueTrack()->isa<AnyValue>());
}
} // namespace abstract

View File

@ -20,7 +20,6 @@ from mindspore import Tensor, Parameter
from mindspore import context
from mindspore import dtype as mstype
from mindspore.nn import Cell
from mindspore.ops import operations as P
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config, \
@ -684,27 +683,6 @@ def test_tensor_assign_bool_index():
net4(Ta, Tensor(u_scalar, mstype.int32))
def test_trivial_call_function_twice_with_diff_key_value_para():
class Net(Cell):
def __init__(self):
super(Net, self).__init__()
self.arange = Tensor(np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))
self.concat = P.Concat(axis=0)
def compute(self, x, is_decoder):
if is_decoder:
return self.arange[:x]
return self.arange[1:x + 1]
def construct(self):
result1 = self.compute(7, is_decoder=True)
result2 = self.compute(6, is_decoder=False)
return self.concat((result1, result2))
net = Net()
net()
test_cases = [
('TensorAssignWithTupleEllipsis2', {
'block': TensorAssignWithTupleEllipsis2(),

View File

@ -19,7 +19,7 @@ from mindspore import Tensor, ms_function
from mindspore import context
from mindspore.ops import operations as P
context.set_context(mode=context.PYNATIVE_MODE, save_graphs=True)
context.set_context(mode=context.PYNATIVE_MODE)
@ms_function
@ -33,8 +33,7 @@ def test_scalar_compute():
p = (3, 4)
q = [5, 6]
w = {"x": 7, "y": 8}
ret = compute(int_x, int_y, p, q, w)
assert ret == -1
compute(int_x, int_y, p, q, w)
def test_tensor_compute():

View File

@ -45,6 +45,17 @@ class GradNet(nn.Cell):
return self.grad_all(self.forward_net)(tuple_a, tensor_x, list_b, tensor_y, scalar, dict_c, flag)
class GradNet1(nn.Cell):
def __init__(self, net, get_all):
super(GradNet1, self).__init__()
self.forward_net = net
self.sens = Tensor(np.ones((2, 2), np.float32) * 5)
self.grad_all = C.GradOperation(get_all=get_all)
def construct(self, tuple_a, tensor_x, list_b, tensor_y, tensor_z, dict_c):
return self.grad_all(self.forward_net)(tuple_a, tensor_x, list_b, tensor_y, tensor_z, dict_c)
x = Tensor(np.ones((2, 2), np.float32))
y = Tensor(np.ones((2, 2), np.float32) * 2)
z = Tensor(np.ones((2, 2), np.float32) * 3)
@ -68,33 +79,18 @@ forward_net = FirstInputTupleNet()
grad_all_inputs_net = GradNet(forward_net, get_all=True)
def test_outermost_net_inputs_including_non_tensor():
forward_net(arg_t0, z, arg_l0, w, sl, args_d0, flag_0)
forward_net(arg_t1, z, arg_l1, x, sl, args_d1, flag_1)
def test_grad_net_inputs_including_non_tensor():
assert len(grad_all_inputs_net(arg_t0, z, arg_l0, w, sl, args_d0, flag_0)) == 2
assert len(grad_all_inputs_net(arg_t1, z, arg_l1, x, sl, args_d1, flag_1)) == 2
def test_grad_first_input_net():
class FirstInputTensorNet(nn.Cell):
def __init__(self):
super(FirstInputTensorNet, self).__init__()
def construct(self, tensor_x, tuple_a, list_b, tensor_y, scalar, dict_c, flag):
if flag:
return tensor_x - tuple_a[2] + list_b[1][1]["x"] - tensor_y + scalar - dict_c["x"]
return tensor_x + tuple_a[2] - list_b[1][1]["y"] + tensor_y - scalar + dict_c["y"]
def construct(self, tensor_x, tuple_a, list_b, tensor_y, tensor_z, dict_c):
return tensor_x + tuple_a[2] - list_b[1][1]["y"] + tensor_y - tensor_z + dict_c["y"]
grad_fist_input_tensor_net = GradNet(FirstInputTensorNet(), get_all=False)
ret = grad_fist_input_tensor_net(z, arg_t0, arg_l0, w, sl, args_d0, flag_0)
grad_fist_input_tensor_net = GradNet1(FirstInputTensorNet(), get_all=False)
ret = grad_fist_input_tensor_net(z, arg_t0, arg_l0, w, y, args_d0)
assert np.allclose(ret.asnumpy(), np.ones((2, 2), np.float32))
grad_fist_input_tuple_net = GradNet(forward_net, get_all=False)
assert not grad_fist_input_tuple_net(arg_t0, z, arg_l0, w, sl, args_d0, flag_0)
def test_net_inputs_including_str():
with pytest.raises(TypeError) as err:

View File

@ -14,7 +14,7 @@
# ============================================================================
""" test_framstruct """
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import context
@ -76,11 +76,13 @@ def dynamic_make_tuple(x, lower, upper):
def test_dynamic_make_tuple():
assert dynamic_make_tuple(2, 1, 5) == (2, 2, 2, 2)
# Dynamically recursively creating static type is invalid in mindspore, as mindspore is a static language.
with pytest.raises(RuntimeError):
dynamic_make_tuple(2, 1, 5)
def test_make_tuple():
# Staticly recursively creating static type is valid in mindspore.
# Statically recursively creating static type is valid in mindspore.
@ms_function
def make_tuple(x):
out = ()