From 41189781f3fe3f4edd4318b109fa9ac2b72d5c10 Mon Sep 17 00:00:00 2001
From: yujianfeng
Date: Sun, 7 Feb 2021 17:04:26 +0800
Subject: [PATCH] support scalar input for cell

---
 mindspore/_checkparam.py                      |  7 ++--
 .../frontend/operator/composite/composite.cc  |  7 ++--
 mindspore/ccsrc/frontend/optimizer/cse.cc     |  3 +-
 mindspore/ccsrc/pipeline/jit/pass.cc          |  3 +-
 mindspore/ccsrc/pipeline/jit/pipeline.cc      | 31 +++++++++--------
 mindspore/common/api.py                       |  2 +-
 mindspore/core/abstract/abstract_value.cc     |  2 +-
 mindspore/core/abstract/prim_others.cc        |  4 ---
 mindspore/nn/cell.py                          |  2 +-
 mindspore/nn/wrap/cell_wrapper.py             | 15 ++++----
 mindspore/ops/operations/array_ops.py         |  3 +-
 tests/st/networks/test_gpu_resnet.py          | 10 +++---
 tests/st/ops/gpu/test_reduce_all_op.py        | 32 +++++++++--------
 tests/st/ops/gpu/test_reduce_any_op.py        | 32 +++++++++--------
 tests/st/ops/gpu/test_reduce_max_op.py        | 33 ++++++++++--------
 tests/st/ops/gpu/test_reduce_mean_op.py       | 32 ++++++++---------
 tests/st/ops/gpu/test_reduce_min_op.py        | 32 +++++++++--------
 tests/st/ops/gpu/test_reduce_sum_op.py        | 33 ++++++++++--------
 tests/ut/cpp/abstract/utils_test.cc           | 12 +++++--
 tests/ut/cpp/optimizer/lib_test.cc            |  7 ++--
 .../cpp/pipeline/static_analysis/data_test.cc |  4 +--
 tests/ut/python/ops/test_tensor_slice.py      | 22 ------------
 ...test_ms_function_pass_non_tensor_inputs.py |  5 ++-
 ...st_outermost_net_pass_non_tensor_inputs.py | 34 ++++++++-----------
 .../python/pynative_mode/test_framstruct.py   |  8 +++--
 25 files changed, 192 insertions(+), 183 deletions(-)

diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py
index d5f8444d45a..1c7668768c0 100644
--- a/mindspore/_checkparam.py
+++ b/mindspore/_checkparam.py
@@ -453,7 +453,7 @@ class Validator:
         return padding
 
     @staticmethod
-    def check_subclass(arg_name, type_, template_types, prim_name):
+    def check_subclass(arg_name, type_, template_types, prim_name, addition_error_info=None):
        """Checks whether some type is subclass of another type"""
        if not isinstance(template_types, Iterable):
            template_types = (template_types,)
@@ -467,9 +467,12 @@ class Validator:
                 hit = True
                 break
         if not hit:
+            if addition_error_info is None:
+                addition_error_info = ''
             type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
             raise TypeError(f'For \'{prim_name}\', the type of `{arg_name}` should be subclass'
-                            f' of {", ".join((str(x) for x in template_types))}, but got {type_str}.')
+                            f' of {", ".join((str(x) for x in template_types))}, but got {type_str}.'
+                            f' {addition_error_info}')
 
     @staticmethod
     def check_const_input(arg_name, arg_value, prim_name):
diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.cc b/mindspore/ccsrc/frontend/operator/composite/composite.cc
index 8aabad979b1..ef76c4d4985 100644
--- a/mindspore/ccsrc/frontend/operator/composite/composite.cc
+++ b/mindspore/ccsrc/frontend/operator/composite/composite.cc
@@ -401,7 +401,9 @@ FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &
   }
 
   if (tail_type_ == kGradFirst) {
-    if (sequeue->size() > 1 && (*sequeue)[1] != nullptr && (*sequeue)[1]->isa<abstract::AbstractUndetermined>()) {
+    if (sequeue->size() > 1 && (*sequeue)[1] != nullptr &&
+        ((*sequeue)[1]->isa<abstract::AbstractUndetermined>() ||
+         ((*sequeue)[1]->BuildType() != nullptr && (*sequeue)[1]->BuildType()->isa<Number>()))) {
       ret->set_output(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(1))}));
     } else {
       ret->set_output(NewValueNode(std::make_shared<ValueTuple>(std::vector<ValuePtr>{})));
@@ -413,7 +415,8 @@ FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &
   for (size_t i = 1; i < sequeue->size(); ++i) {
     if (tail_type_ == kGradAll) {
       MS_EXCEPTION_IF_NULL((*sequeue)[i]);
-      if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>()) {
+      if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>() ||
+          ((*sequeue)[i]->BuildType() != nullptr && (*sequeue)[i]->BuildType()->isa<Number>())) {
         elems.push_back(ret->NewCNodeInOrder({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))}));
       }
     } else {
diff --git a/mindspore/ccsrc/frontend/optimizer/cse.cc b/mindspore/ccsrc/frontend/optimizer/cse.cc
index c8305219db3..76c84e85e61 100644
--- a/mindspore/ccsrc/frontend/optimizer/cse.cc
+++ b/mindspore/ccsrc/frontend/optimizer/cse.cc
@@ -224,7 +224,8 @@ void DeleteLoadUserUpdateState(const FuncGraphManagerPtr &manager, const AnfNode
 //   b = Load(para1, u2)
 //   u3 = UpdateState(u2, x)
 void DeleteLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const CNodePtr &make_tuple, const AnfNodePtr &load) {
-  AnfNodePtr other_input = nullptr;
+  // Initialize other_input with load, in case every input of the make_tuple is the same load.
+  AnfNodePtr other_input = load;
   for (size_t i = 1; i < make_tuple->size(); i++) {
     if (make_tuple->input(i) != load) {
       other_input = make_tuple->input(i);
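The Tail change in composite.cc above is what lets GradOperation hand back gradients for scalar (Number-typed) positions, not only tensors. A minimal sketch of the newly supported pattern, assuming graph mode; the network, names, and values below are illustrative and not taken from this patch:

    import numpy as np
    import mindspore.nn as nn
    import mindspore.ops.composite as C
    from mindspore import Tensor, context

    context.set_context(mode=context.GRAPH_MODE)

    class MulNet(nn.Cell):
        """Hypothetical net mixing a tensor input and a scalar input."""
        def construct(self, x, scale):
            return x * scale

    class GradAllNet(nn.Cell):
        def __init__(self, net):
            super(GradAllNet, self).__init__()
            self.net = net
            self.grad_all = C.GradOperation(get_all=True)

        def construct(self, x, scale):
            # With the kGradAll branch above, the returned tuple should now
            # include an entry for the scalar input `scale` as well.
            return self.grad_all(self.net)(x, scale)

    x = Tensor(np.ones((2, 2), np.float32))
    grads = GradAllNet(MulNet())(x, 3.0)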
diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc
index 8ef2e39a942..09401c4b892 100644
--- a/mindspore/ccsrc/pipeline/jit/pass.cc
+++ b/mindspore/ccsrc/pipeline/jit/pass.cc
@@ -489,7 +489,8 @@ void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) {
       continue;
     }
     AbstractBasePtr par_abs = param_node->abstract();
-    if (par_abs->isa<abstract::AbstractUndetermined>()) {
+    if (par_abs->isa<abstract::AbstractUndetermined>() ||
+        (par_abs->BuildType() != nullptr && par_abs->BuildType()->isa<Number>())) {
       new_paras.push_back(param_node);
     }
   }
diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc
index 63a1eeccc1d..5f6ffb84990 100644
--- a/mindspore/ccsrc/pipeline/jit/pipeline.cc
+++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc
@@ -98,7 +98,7 @@ std::string GetBaseNameForIR(int64_t stage_idx, const std::string &action_name)
 
 AbstractBasePtr ArgsToAbstract(const ValuePtr &value) {
   MS_EXCEPTION_IF_NULL(value);
-  bool broaden = value->isa<MetaTensor>();
+  bool broaden = value->isa<MetaTensor>() || value->isa<Scalar>();
   return abstract::FromValue(value, broaden);
 }
 
@@ -142,6 +142,21 @@ std::string GetCompileExceptionInfo() {
   return oss.str();
 }
 
+void SetGpuLoopSink(const ResourcePtr &resource_) {
+  auto func_graph = resource_->func_graph();
+  if (func_graph != nullptr && func_graph->manager() != nullptr) {
+    auto manager = func_graph->manager();
+    size_t graph_nums = manager->func_graphs().size();
+    int64_t sinksize = ConfigManager::GetInstance().iter_num();
+    if (graph_nums == 1) {
+      resource_->set_gpu_loopsink(true, sinksize);
+    } else {
+      resource_->set_gpu_loopsink(false, sinksize);
+    }
+    MS_LOG(INFO) << "Change gpu_loopsink_flag_ to " << resource_->gpu_loopsink_flag() << ", set loopsink size to "
+                 << sinksize;
+  }
+}
 }  // namespace
 
 py::tuple GenerateKey(const std::string &name, const std::unordered_map<std::string, py::object> &defaults) {
@@ -704,19 +719,7 @@ void Pipeline::Run() {
       MS_LOG(DEBUG) << "Action " << action.first << " end.";
     };
     if (action.first == "task_emit") {
-      auto func_graph = resource_->func_graph();
-      if (func_graph != nullptr && func_graph->manager() != nullptr) {
-        auto manager = func_graph->manager();
-        size_t graph_nums = manager->func_graphs().size();
-        int64_t sinksize = ConfigManager::GetInstance().iter_num();
-        if (graph_nums == 1) {
-          resource_->set_gpu_loopsink(true, sinksize);
-        } else {
-          resource_->set_gpu_loopsink(false, sinksize);
-        }
-        MS_LOG(INFO) << "Change gpu_loopsink_flag_ to " << resource_->gpu_loopsink_flag() << ", set loopsink size to "
-                     << sinksize;
-      }
+      SetGpuLoopSink(resource_);
     }
     if (!result) {
       MS_LOG(EXCEPTION) << "Pipeline running to end, failed in step:" << action.first;
diff --git a/mindspore/common/api.py b/mindspore/common/api.py
index 1312e8ee4a8..257ee65e195 100644
--- a/mindspore/common/api.py
+++ b/mindspore/common/api.py
@@ -210,7 +210,7 @@ class _MindSporeFunction:
             return None
         new_inputs = []
         for i in args_list:
-            if isinstance(i, Tensor):
+            if isinstance(i, (Tensor, int, float)):
                 new_inputs.append(i)
         return self._executor(tuple(new_inputs), phase)
 
diff --git a/mindspore/core/abstract/abstract_value.cc b/mindspore/core/abstract/abstract_value.cc
index 9a4ec269a62..8dd3e420658 100644
--- a/mindspore/core/abstract/abstract_value.cc
+++ b/mindspore/core/abstract/abstract_value.cc
@@ -88,7 +88,7 @@ std::string AbstractBase::ToString() const {
   return buffer.str();
 }
 
-AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const { return Clone(); }
+AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const { return AbstractBase::Broaden(config); }
 
 AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
   MS_EXCEPTION_IF_NULL(other);
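Taken together, ArgsToAbstract broadening Scalar values and _MindSporeFunction keeping int/float arguments mean scalar arguments are compiled as real graph inputs rather than baked-in constants. A minimal sketch under that assumption (function name and values are illustrative, not from this patch):

    from mindspore import ms_function

    @ms_function
    def scalar_compute(x, y):
        # x and y arrive as broadened scalar inputs, not compile-time constants.
        return x * 2 + y

    out_a = scalar_compute(1, 2)
    out_b = scalar_compute(3, 4)  # can reuse the graph compiled for scalar inputs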
diff --git a/mindspore/core/abstract/prim_others.cc b/mindspore/core/abstract/prim_others.cc
index 26a9b080cf2..c8af8416be3 100644
--- a/mindspore/core/abstract/prim_others.cc
+++ b/mindspore/core/abstract/prim_others.cc
@@ -171,10 +171,6 @@ AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &p
     return args_spec_list[0];
   }
   auto depends = args_spec_list[0]->Broaden();
-  // For scalar, need to set value to kAnyValue, because broaden scalar will not change the value.
-  if (depends->isa<AbstractScalar>()) {
-    depends->set_value(kAnyValue);
-  }
   return depends;
 }
 
diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py
index 715b141fc49..68e73152115 100755
--- a/mindspore/nn/cell.py
+++ b/mindspore/nn/cell.py
@@ -609,7 +609,7 @@ class Cell(Cell_):
 
         new_inputs = []
         for i in inputs:
-            if isinstance(i, Tensor):
+            if isinstance(i, (Tensor, int, float)):
                 new_inputs.append(i)
 
         if self._auto_parallel_mode:
diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py
index db15a98a3c8..b11b9dc9dc6 100644
--- a/mindspore/nn/wrap/cell_wrapper.py
+++ b/mindspore/nn/wrap/cell_wrapper.py
@@ -199,10 +199,10 @@ class ForwardValueAndGrad(Cell):
             If sens_param is True, a sensitivity (gradient with respect to output) needs to be
             transferred through the positional parameter or key-value pair parameter. If the value is
             transferred through the key-value pair parameter, the key must be sens.
+        sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
 
     Inputs:
         - **(\*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
-        - sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
 
     Outputs:
         - **forward value** (a scalar Tensor with shape :math:`()`) - The result of the network's forward run.
@@ -242,7 +242,7 @@ class ForwardValueAndGrad(Cell):
         >>> loss, grads = forward_value_and_grad(inputs, labels, 1.0)
     """
 
-    def __init__(self, network, weights=None, get_all=False, get_by_list=False, sens_param=False):
+    def __init__(self, network, weights=None, get_all=False, get_by_list=False, sens_param=False, sens=1.0):
         super(ForwardValueAndGrad, self).__init__(auto_prefix=False)
         if not isinstance(network, (Cell, FunctionType, MethodType)):
             raise TypeError(f"The type of training network should be cell, function type or method type, "
@@ -259,19 +259,16 @@ class ForwardValueAndGrad(Cell):
         self.get_all = get_all
         self.get_by_list = get_by_list
         self.sens_param = sens_param
+        self.sens = sens
         self.grad = C.GradOperation(get_all=self.get_all, get_by_list=self.get_by_list, sens_param=self.sens_param)
 
     def construct(self, *inputs):
         weights = self.weights
-        if self.sens_param:
-            sens = inputs[-1]
-            inputs = inputs[:-1]
-        else:
-            sens = None
         loss = self.network(*inputs)
         if self.sens_param:
-            if not isinstance(sens, Tensor):
-                sens = P.Fill()(P.DType()(loss), P.Shape()(loss), sens)
+            sens = self.sens
+            if not isinstance(self.sens, Tensor):
+                sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
             grads = self.grad(self.network, weights)(*inputs, sens)
         else:
             grads = self.grad(self.network, weights)(*inputs)
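With sens moved into the constructor, the sensitivity is configured once instead of being threaded through every call, and the call site passes only the data, exactly as the updated GPU ResNet tests later in this patch do. A minimal usage sketch, assuming a stand-in nn.Dense network rather than anything from this patch:

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, ParameterTuple
    from mindspore.nn.wrap.cell_wrapper import ForwardValueAndGrad

    net = nn.Dense(3, 1)  # hypothetical network
    weights = ParameterTuple(net.trainable_params())
    # sens is now fixed at construction time (default 1.0).
    forward_and_grad = ForwardValueAndGrad(net, weights=weights, get_by_list=True,
                                           sens_param=True, sens=1.0)
    data = Tensor(np.ones([4, 3]).astype(np.float32))
    value, grads = forward_and_grad(data)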
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index 7f976014fab..a4af75548f9 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -223,7 +223,8 @@ class DType(PrimitiveWithInfer):
         """Initialize DType"""
 
     def __infer__(self, x):
-        validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name)
+        addition_error_info = 'Perhaps you are using a mixture of tensors and scalars to operate.'
+        validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name, addition_error_info)
         out = {'shape': (),
                'dtype': mstype.type_type,
                'value': x['dtype'].element_type()}
diff --git a/tests/st/networks/test_gpu_resnet.py b/tests/st/networks/test_gpu_resnet.py
index 4e4d4a8c326..85f2a2a1f83 100644
--- a/tests/st/networks/test_gpu_resnet.py
+++ b/tests/st/networks/test_gpu_resnet.py
@@ -414,13 +414,14 @@ def test_trainTensor_with_new_interface(num_classes=10, epoch=8, batch_size=1):
     weights = ParameterTuple(filter(lambda x: x.requires_grad, net.get_parameters()))
     optimizer = Momentum(weights, 0.1, 0.9)
 
-    train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True)
+    train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True,
+                                        sens=1.0)
     losses = []
     for i in range(0, epoch):
         data = Tensor(np.ones([batch_size, 3, 224, 224]).astype(np.float32) * 0.01)
         label = Tensor(np.ones([batch_size]).astype(np.int32))
-        loss, grads = train_network(data, label, 1.0)
+        loss, grads = train_network(data, label)
         grads = F.identity(grads)
         optimizer(grads)
         losses.append(loss)
@@ -439,13 +440,14 @@ def test_big_batchSize_with_new_interface(num_classes=10, epoch=8, batch_size=33
     weights = ParameterTuple(filter(lambda x: x.requires_grad, net.get_parameters()))
     optimizer = Momentum(weights, 0.1, 0.9)
 
-    train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True)
+    train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True,
+                                        sens=1.0)
     losses = []
     for i in range(0, epoch):
         data = Tensor(np.ones([batch_size, 3, 224, 224]).astype(np.float32) * 0.01)
         label = Tensor(np.ones([batch_size]).astype(np.int32))
-        loss, grads = train_network(data, label, 1.0)
+        loss, grads = train_network(data, label)
         grads = F.identity(grads)
         optimizer(grads)
         losses.append(loss)
diff --git a/tests/st/ops/gpu/test_reduce_all_op.py b/tests/st/ops/gpu/test_reduce_all_op.py
index bfc36c018fd..9daf7a04c50 100644
--- a/tests/st/ops/gpu/test_reduce_all_op.py
+++ b/tests/st/ops/gpu/test_reduce_all_op.py
@@ -95,15 +95,23 @@ def test_ReduceAll():
     assert output[3].shape == expect3.shape
 
 
+x_1 = np.array([[True, True], [True, False], [False, False]])
+axis_1 = 0
+x_2 = np.array([[True, True], [True, True], [True, False], [False, False]])
+axis_2 = 0
+
+
 class ReduceAllDynamic(nn.Cell):
-    def __init__(self):
+    def __init__(self, x, axis):
         super(ReduceAllDynamic, self).__init__()
         self.reduceall = P.ReduceAll(False)
         self.test_dynamic = inner.GpuConvertToDynamicShape()
+        self.x = x
+        self.axis = axis
 
-    def construct(self, x, axis):
-        x = self.test_dynamic(x)
-        return self.reduceall(x, axis)
+    def construct(self):
+        dynamic_x = self.test_dynamic(self.x)
+        return self.reduceall(dynamic_x, self.axis)
 
 
 @pytest.mark.level0
@@ -111,18 +119,14 @@ class ReduceAllDynamic(nn.Cell):
 @pytest.mark.env_onecard
 def test_reduce_all_dynamic():
     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
-    net = ReduceAllDynamic()
+    net1 = ReduceAllDynamic(Tensor(x_1), axis_1)
+    net2 = ReduceAllDynamic(Tensor(x_2), axis_2)
 
-    x_1 = np.array([[True, True], [True, False], [False, False]])
-    axis_1 = 0
     expect_1 = np.all(x_1, axis=axis_1, keepdims=False)
-
-    x_2 = np.array([[True, True], [True, True], [True, False], [False, False]])
-    axis_2 = 0
     expect_2 = np.all(x_2, axis=axis_2, keepdims=False)
 
-    output_1 = net(Tensor(x_1), axis_1)
-    output_2 = net(Tensor(x_2), axis_2)
+    output1 = net1()
+    output2 = net2()
 
-    np.testing.assert_almost_equal(output_1.asnumpy(), expect_1)
-    np.testing.assert_almost_equal(output_2.asnumpy(), expect_2)
+    np.testing.assert_almost_equal(output1.asnumpy(), expect_1)
+    np.testing.assert_almost_equal(output2.asnumpy(), expect_2)
diff --git a/tests/st/ops/gpu/test_reduce_any_op.py b/tests/st/ops/gpu/test_reduce_any_op.py
index 51874a699a6..7da47b09ead 100644
--- a/tests/st/ops/gpu/test_reduce_any_op.py
+++ b/tests/st/ops/gpu/test_reduce_any_op.py
@@ -95,15 +95,23 @@ def test_ReduceAny():
     assert output[3].shape == expect3.shape
 
 
+x_1 = np.array([[True, True], [True, False], [False, False]])
+axis_1 = 0
+x_2 = np.array([[True, True], [True, True], [True, False], [False, False]])
+axis_2 = 0
+
+
 class ReduceAnyDynamic(nn.Cell):
-    def __init__(self):
+    def __init__(self, x, axis):
         super(ReduceAnyDynamic, self).__init__()
         self.reduceany = P.ReduceAny(False)
         self.test_dynamic = inner.GpuConvertToDynamicShape()
+        self.x = x
+        self.axis = axis
 
-    def construct(self, x, axis):
-        x = self.test_dynamic(x)
-        return self.reduceany(x, axis)
+    def construct(self):
+        dynamic_x = self.test_dynamic(self.x)
+        return self.reduceany(dynamic_x, self.axis)
 
 
 @pytest.mark.level0
@@ -111,18 +119,14 @@ class ReduceAnyDynamic(nn.Cell):
 @pytest.mark.env_onecard
 def test_reduce_any_dynamic():
     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
-    net = ReduceAnyDynamic()
+    net1 = ReduceAnyDynamic(Tensor(x_1), axis_1)
+    net2 = ReduceAnyDynamic(Tensor(x_2), axis_2)
 
-    x_1 = np.array([[True, True], [True, False], [False, False]])
-    axis_1 = 0
     expect_1 = np.any(x_1, axis=axis_1, keepdims=False)
-
-    x_2 = np.array([[True, True], [True, True], [True, False], [False, False]])
-    axis_2 = 0
     expect_2 = np.any(x_2, axis=axis_2, keepdims=False)
 
-    output_1 = net(Tensor(x_1), axis_1)
-    output_2 = net(Tensor(x_2), axis_2)
+    output1 = net1()
+    output2 = net2()
 
-    np.testing.assert_almost_equal(output_1.asnumpy(), expect_1)
-    np.testing.assert_almost_equal(output_2.asnumpy(), expect_2)
+    np.testing.assert_almost_equal(output1.asnumpy(), expect_1)
+    np.testing.assert_almost_equal(output2.asnumpy(), expect_2)
diff --git a/tests/st/ops/gpu/test_reduce_max_op.py b/tests/st/ops/gpu/test_reduce_max_op.py
index 46943b4bbbe..e44e7442726 100644
--- a/tests/st/ops/gpu/test_reduce_max_op.py
+++ b/tests/st/ops/gpu/test_reduce_max_op.py
@@ -179,36 +179,41 @@ def test_ReduceMax():
     assert np.all(diff8 < error8)
 
 
+x_1 = x8
+axis_1 = 0
+x_2 = x1
+axis_2 = 0
+
+
 class ReduceMaxDynamic(nn.Cell):
-    def __init__(self):
+    def __init__(self, x, axis):
         super(ReduceMaxDynamic, self).__init__()
         self.reducemax = P.ReduceMax(False)
         self.test_dynamic = inner.GpuConvertToDynamicShape()
+        self.x = x
+        self.axis = axis
 
-    def construct(self, x, axis):
-        x = self.test_dynamic(x)
-        return self.reducemax(x, axis)
+    def construct(self):
+        dynamic_x = self.test_dynamic(self.x)
+        return self.reducemax(dynamic_x, self.axis)
 
 
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
 def test_reduce_max_dynamic():
     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
-    net = ReduceMaxDynamic()
+    net1 = ReduceMaxDynamic(Tensor(x_1), axis_1)
+    net2 = ReduceMaxDynamic(Tensor(x_2), axis_2)
 
-    x_1 = x8
-    axis_1 = 0
     expect_1 = np.max(x_1, axis=0, keepdims=False)
-
-    x_2 = x1
-    axis_2 = 0
     expect_2 = np.max(x_2, axis=0, keepdims=False)
 
-    output_1 = net(Tensor(x_1), axis_1)
-    output_2 = net(Tensor(x_2), axis_2)
+    output1 = net1()
+    output2 = net2()
+
+    np.testing.assert_almost_equal(output1.asnumpy(), expect_1)
+    np.testing.assert_almost_equal(output2.asnumpy(), expect_2)
 
-    np.testing.assert_almost_equal(output_1.asnumpy(), expect_1)
-    np.testing.assert_almost_equal(output_2.asnumpy(), expect_2)
 
 class ReduceMaxTypeNet(nn.Cell):
     def __init__(self, nptype):
diff --git a/tests/st/ops/gpu/test_reduce_mean_op.py b/tests/st/ops/gpu/test_reduce_mean_op.py
index 5b4c3396a87..8fb6abdc3c0 100644
--- a/tests/st/ops/gpu/test_reduce_mean_op.py
+++ b/tests/st/ops/gpu/test_reduce_mean_op.py
@@ -268,14 +268,16 @@ def test_ReduceMean():
     assert output[14].shape == expect14.shape
 
 class ReduceMeanDynamic(nn.Cell):
-    def __init__(self, keepdims=False):
+    def __init__(self, x, axis, keepdims=False):
         super(ReduceMeanDynamic, self).__init__()
         self.test_dynamic = inner.GpuConvertToDynamicShape()
         self.reducemean = P.ReduceMean(keep_dims=keepdims)
+        self.x = x
+        self.axis = axis
 
-    def construct(self, input_x, axis):
-        input_x = self.test_dynamic(input_x)
-        output = self.reducemean(input_x, axis)
+    def construct(self):
+        dynamic_x = self.test_dynamic(self.x)
+        output = self.reducemean(dynamic_x, self.axis)
         return output
 
 @pytest.mark.level0
@@ -283,32 +285,30 @@ class ReduceMeanDynamic(nn.Cell):
 @pytest.mark.env_onecard
 def test_dynamic_reduce_mean_keepdims_true():
     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
-    net = ReduceMeanDynamic(keepdims=True)
-    x_tensor_1 = Tensor(x14)
-    output_1 = net(x_tensor_1, axis14)
-    x_tensor_2 = Tensor(x0)
-    output_2 = net(x_tensor_2, axis0)
+    net1 = ReduceMeanDynamic(Tensor(x14), axis14, keepdims=True)
+    net2 = ReduceMeanDynamic(Tensor(x0), axis0, keepdims=True)
+    output1 = net1()
+    output2 = net2()
 
     expect_1 = np.mean(x14, axis=np_axis14, keepdims=True)
-    diff_1 = abs(output_1.asnumpy() - expect_1)
+    diff_1 = abs(output1.asnumpy() - expect_1)
     error_1 = np.ones(shape=expect_1.shape) * 1.0e-5
     assert np.all(diff_1 < error_1)
-    assert output_1.shape == expect_1.shape
+    assert output1.shape == expect_1.shape
 
     expect_2 = np.mean(x0, axis=axis0, keepdims=True)
-    diff_2 = abs(output_2.asnumpy() - expect_2)
+    diff_2 = abs(output2.asnumpy() - expect_2)
     error_2 = np.ones(shape=expect_2.shape) * 1.0e-5
     assert np.all(diff_2 < error_2)
-    assert output_2.shape == expect_2.shape
+    assert output2.shape == expect_2.shape
 
 
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
 def test_dynamic_reduce_mean_keepdims_false():
     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
-    net = ReduceMeanDynamic(keepdims=False)
-    x_tensor = Tensor(x12)
-    output = net(x_tensor, axis12)
+    net = ReduceMeanDynamic(Tensor(x12), axis12, keepdims=False)
+    output = net()
 
     expect = np.mean(x12, axis=axis12, keepdims=False)
     diff = abs(output.asnumpy() - expect)
diff --git a/tests/st/ops/gpu/test_reduce_min_op.py b/tests/st/ops/gpu/test_reduce_min_op.py
index 5008e6116c0..9492f19962a 100644
--- a/tests/st/ops/gpu/test_reduce_min_op.py
+++ b/tests/st/ops/gpu/test_reduce_min_op.py
@@ -179,33 +179,37 @@ def test_ReduceMin():
     assert np.all(diff8 < error8)
 
 
+x_1 = x8
+axis_1 = 0
+x_2 = x1
+axis_2 = 0
+
+
 class ReduceMinDynamic(nn.Cell):
-    def __init__(self):
+    def __init__(self, x, axis):
         super(ReduceMinDynamic, self).__init__()
         self.reducemin = P.ReduceMin(False)
         self.test_dynamic = inner.GpuConvertToDynamicShape()
+        self.x = x
+        self.axis = axis
 
-    def construct(self, x, axis):
-        x = self.test_dynamic(x)
-        return self.reducemin(x, axis)
+    def construct(self):
+        dynamic_x = self.test_dynamic(self.x)
+        return self.reducemin(dynamic_x, self.axis)
 
 
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
 def test_reduce_min_dynamic():
     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
-    net = ReduceMinDynamic()
+    net1 = ReduceMinDynamic(Tensor(x_1), axis_1)
+    net2 = ReduceMinDynamic(Tensor(x_2), axis_2)
 
-    x_1 = x8
-    axis_1 = 0
     expect_1 = np.min(x_1, axis=0, keepdims=False)
-
-    x_2 = x1
-    axis_2 = 0
     expect_2 = np.min(x_2, axis=0, keepdims=False)
 
-    output_1 = net(Tensor(x_1), axis_1)
-    output_2 = net(Tensor(x_2), axis_2)
+    output1 = net1()
+    output2 = net2()
 
-    np.testing.assert_almost_equal(output_1.asnumpy(), expect_1)
-    np.testing.assert_almost_equal(output_2.asnumpy(), expect_2)
+    np.testing.assert_almost_equal(output1.asnumpy(), expect_1)
+    np.testing.assert_almost_equal(output2.asnumpy(), expect_2)
diff --git a/tests/st/ops/gpu/test_reduce_sum_op.py b/tests/st/ops/gpu/test_reduce_sum_op.py
index 708deb21845..878c9df2847 100644
--- a/tests/st/ops/gpu/test_reduce_sum_op.py
+++ b/tests/st/ops/gpu/test_reduce_sum_op.py
@@ -270,15 +270,23 @@ def test_ReduceSum():
     assert output[14].shape == expect14.shape
 
 
+x_1 = x8
+axis_1 = 0
+x_2 = x1
+axis_2 = 0
+
+
 class ReduceSumDynamic(nn.Cell):
-    def __init__(self):
+    def __init__(self, x, axis):
         super(ReduceSumDynamic, self).__init__()
         self.reducesum = P.ReduceSum(True)
         self.test_dynamic = inner.GpuConvertToDynamicShape()
+        self.x = x
+        self.axis = axis
 
-    def construct(self, x, axis):
-        x = self.test_dynamic(x)
-        return self.reducesum(x, axis)
+    def construct(self):
+        dynamic_x = self.test_dynamic(self.x)
+        return self.reducesum(dynamic_x, self.axis)
 
 
 @pytest.mark.level0
@@ -286,21 +294,18 @@ class ReduceSumDynamic(nn.Cell):
 @pytest.mark.env_onecard
 def test_reduce_sum_dynamic():
     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
-    net = ReduceSumDynamic()
+    net1 = ReduceSumDynamic(Tensor(x_1), axis_1)
+    net2 = ReduceSumDynamic(Tensor(x_2), axis_2)
 
-    x_1 = x8
-    axis_1 = 0
     expect_1 = np.sum(x_1, axis=axis_1, keepdims=True)
-
-    x_2 = x1
-    axis_2 = 0
     expect_2 = np.sum(x_2, axis=axis_2, keepdims=True)
 
-    output_1 = net(Tensor(x_1), axis_1)
-    output_2 = net(Tensor(x_2), axis_2)
+    output1 = net1()
+    output2 = net2()
+
+    np.testing.assert_almost_equal(output1.asnumpy(), expect_1)
+    np.testing.assert_almost_equal(output2.asnumpy(), expect_2)
 
-    np.testing.assert_almost_equal(output_1.asnumpy(), expect_1)
-    np.testing.assert_almost_equal(output_2.asnumpy(), expect_2)
 
 class ReduceSumTypeNet(nn.Cell):
     def __init__(self, nptype):
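All six reduce-op tests above move the input data and the reduction axis from construct arguments into attributes fixed at construction. A plausible reading: with scalars now forwarded as real graph inputs, an axis passed at call time would no longer be a compile-time constant, which the reduce primitives require. A sketch of the same pattern for a hypothetical ReduceProd test (names and values are illustrative, not part of this patch):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.ops import operations as P
    from mindspore.ops.operations import _inner_ops as inner

    class ReduceProdDynamic(nn.Cell):
        def __init__(self, x, axis):
            super(ReduceProdDynamic, self).__init__()
            self.reduceprod = P.ReduceProd(False)
            self.test_dynamic = inner.GpuConvertToDynamicShape()
            self.x = x        # input held as an attribute
            self.axis = axis  # axis stays a Python int constant

        def construct(self):
            dynamic_x = self.test_dynamic(self.x)
            return self.reduceprod(dynamic_x, self.axis)

    net = ReduceProdDynamic(Tensor(np.ones((4, 2)).astype(np.float32)), 0)
    output = net()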
diff --git a/tests/ut/cpp/abstract/utils_test.cc b/tests/ut/cpp/abstract/utils_test.cc
index ff44c1c0409..ea954c0641f 100644
--- a/tests/ut/cpp/abstract/utils_test.cc
+++ b/tests/ut/cpp/abstract/utils_test.cc
@@ -32,18 +32,26 @@ TEST_F(TestUtils, test_join) {
   AbstractBasePtr abs_s1 = FromValue(static_cast<int64_t>(1), false);
   AbstractBasePtr abs_s2 = FromValue(static_cast<int64_t>(2), false);
   AbstractBasePtr abs_s_anything = FromValue(static_cast<int64_t>(2), true);
-  abs_s_anything->set_value(kAnyValue);
 
   AbstractBasePtr res_s1 = abs_s1->Join(abs_s2);
   ASSERT_EQ(*res_s1, *abs_s_anything);
 
+  // AbstractTuple join;
+  std::vector<int64_t> list1 = {1, 2, 3, 4, 5};
+  std::vector<int64_t> list2 = {5, 4, 3, 2, 1};
+  AbstractBasePtr abs_t1 = FromValue(list1, true);
+  AbstractBasePtr abs_t2 = FromValue(list2, true);
+
+  AbstractBasePtr res_t1 = abs_t1->Join(abs_t2);
+  ASSERT_EQ(res_t1, abs_t1);
+
   abs_s1 = FromValue(static_cast<int64_t>(1), false);
 
   AbstractBasePtr t1 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything}));
   AbstractBasePtr t2 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything}));
   AbstractBasePtr t3 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s_anything, abs_s_anything}));
 
-  AbstractBasePtr res_t1 = t1->Join(t2);
+  res_t1 = t1->Join(t2);
   ASSERT_EQ(res_t1, t1);
 
   res_t1 = t1->Join(t3);
diff --git a/tests/ut/cpp/optimizer/lib_test.cc b/tests/ut/cpp/optimizer/lib_test.cc
index 38881362d5a..3c794f97a8d 100644
--- a/tests/ut/cpp/optimizer/lib_test.cc
+++ b/tests/ut/cpp/optimizer/lib_test.cc
@@ -111,11 +111,8 @@ TEST_F(TestOptLib, test_inline) {
   // add infer and renormalize
   std::shared_ptr<pipeline::Resource> res = std::make_shared<pipeline::Resource>();
   AbstractBasePtrList args_spec_list;
-  tensor::TensorPtr x_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{2, 3});
-  tensor::TensorPtr y_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{2, 3});
-
-  AbstractBasePtr abstract_v1 = abstract::FromValue(x_tensor, true);
-  AbstractBasePtr abstract_v2 = abstract::FromValue(y_tensor, true);
+  AbstractBasePtr abstract_v1 = abstract::FromValue(static_cast<int64_t>(1), true);
+  AbstractBasePtr abstract_v2 = abstract::FromValue(static_cast<int64_t>(2), true);
   args_spec_list.push_back(abstract_v1);
   args_spec_list.push_back(abstract_v2);
   AnalysisResult result = pipeline::AbstractAnalyze(res, before1, args_spec_list);
diff --git a/tests/ut/cpp/pipeline/static_analysis/data_test.cc b/tests/ut/cpp/pipeline/static_analysis/data_test.cc
index 248bf362bbe..5c333ed52f9 100644
--- a/tests/ut/cpp/pipeline/static_analysis/data_test.cc
+++ b/tests/ut/cpp/pipeline/static_analysis/data_test.cc
@@ -184,7 +184,7 @@ TEST_F(TestData, test_broaden) {
   AbstractBasePtr s2 = s1->Broaden();
   ASSERT_TRUE(*s1->GetTypeTrack() == *s2->GetTypeTrack());
   ASSERT_TRUE(*s1->GetValueTrack() == *MakeValue(int1));
-  ASSERT_TRUE(s2->GetValueTrack()->isa<Int64Imm>());
+  ASSERT_TRUE(s2->GetValueTrack()->isa<AnyValue>());
 
   AbstractFunctionPtr f1 = std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(), AnalysisContext::DummyContext());
@@ -196,7 +196,7 @@ TEST_F(TestData, test_broaden) {
   AbstractList *l2_cast = dynamic_cast<AbstractList *>(l2.get());
   ASSERT_TRUE(l2_cast != nullptr);
   AbstractBasePtr csr = AbstractJoin(l2_cast->elements());
-  ASSERT_TRUE(csr->GetValueTrack()->isa<Int64Imm>());
+  ASSERT_TRUE(csr->GetValueTrack()->isa<AnyValue>());
 }
 
 }  // namespace abstract
diff --git a/tests/ut/python/ops/test_tensor_slice.py b/tests/ut/python/ops/test_tensor_slice.py
index 5e1c02aa0e5..de8190d0ccb 100644
--- a/tests/ut/python/ops/test_tensor_slice.py
+++ b/tests/ut/python/ops/test_tensor_slice.py
@@ -20,7 +20,6 @@ from mindspore import Tensor, Parameter
 from mindspore import context
 from mindspore import dtype as mstype
 from mindspore.nn import Cell
-from mindspore.ops import operations as P
 from ....mindspore_test_framework.mindspore_test import mindspore_test
 from ....mindspore_test_framework.pipeline.forward.compile_forward \
     import pipeline_for_compile_forward_ge_graph_for_case_by_case_config, \
@@ -684,27 +683,6 @@ def test_tensor_assign_bool_index():
         net4(Ta, Tensor(u_scalar, mstype.int32))
 
 
-def test_trivial_call_function_twice_with_diff_key_value_para():
-    class Net(Cell):
-        def __init__(self):
-            super(Net, self).__init__()
-            self.arange = Tensor(np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))
-            self.concat = P.Concat(axis=0)
-
-        def compute(self, x, is_decoder):
-            if is_decoder:
-                return self.arange[:x]
-            return self.arange[1:x + 1]
-
-        def construct(self):
-            result1 = self.compute(7, is_decoder=True)
-            result2 = self.compute(6, is_decoder=False)
-            return self.concat((result1, result2))
-
-    net = Net()
-    net()
-
-
 test_cases = [
     ('TensorAssignWithTupleEllipsis2', {
        'block': TensorAssignWithTupleEllipsis2(),
diff --git a/tests/ut/python/pipeline/parse/test_ms_function_pass_non_tensor_inputs.py b/tests/ut/python/pipeline/parse/test_ms_function_pass_non_tensor_inputs.py
index 3267b3fa900..f56c3c4bbb7 100644
--- a/tests/ut/python/pipeline/parse/test_ms_function_pass_non_tensor_inputs.py
+++ b/tests/ut/python/pipeline/parse/test_ms_function_pass_non_tensor_inputs.py
@@ -19,7 +19,7 @@ from mindspore import Tensor, ms_function
 from mindspore import context
 from mindspore.ops import operations as P
 
-context.set_context(mode=context.PYNATIVE_MODE, save_graphs=True)
+context.set_context(mode=context.PYNATIVE_MODE)
 
 
 @ms_function
@@ -33,8 +33,7 @@ def test_scalar_compute():
     p = (3, 4)
     q = [5, 6]
     w = {"x": 7, "y": 8}
-    ret = compute(int_x, int_y, p, q, w)
-    assert ret == -1
+    compute(int_x, int_y, p, q, w)
 
 
 def test_tensor_compute():
diff --git a/tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py b/tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py
index 246370cc1e2..378e7517c60 100644
--- a/tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py
+++ b/tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py
@@ -45,6 +45,17 @@ class GradNet(nn.Cell):
         return self.grad_all(self.forward_net)(tuple_a, tensor_x, list_b, tensor_y, scalar, dict_c, flag)
 
 
+class GradNet1(nn.Cell):
+    def __init__(self, net, get_all):
+        super(GradNet1, self).__init__()
+        self.forward_net = net
+        self.sens = Tensor(np.ones((2, 2), np.float32) * 5)
+        self.grad_all = C.GradOperation(get_all=get_all)
+
+    def construct(self, tuple_a, tensor_x, list_b, tensor_y, tensor_z, dict_c):
+        return self.grad_all(self.forward_net)(tuple_a, tensor_x, list_b, tensor_y, tensor_z, dict_c)
+
+
 x = Tensor(np.ones((2, 2), np.float32))
 y = Tensor(np.ones((2, 2), np.float32) * 2)
 z = Tensor(np.ones((2, 2), np.float32) * 3)
@@ -68,33 +79,18 @@ forward_net = FirstInputTupleNet()
 grad_all_inputs_net = GradNet(forward_net, get_all=True)
 
 
-def test_outermost_net_inputs_including_non_tensor():
-    forward_net(arg_t0, z, arg_l0, w, sl, args_d0, flag_0)
-    forward_net(arg_t1, z, arg_l1, x, sl, args_d1, flag_1)
-
-
-def test_grad_net_inputs_including_non_tensor():
-    assert len(grad_all_inputs_net(arg_t0, z, arg_l0, w, sl, args_d0, flag_0)) == 2
-    assert len(grad_all_inputs_net(arg_t1, z, arg_l1, x, sl, args_d1, flag_1)) == 2
-
-
 def test_grad_first_input_net():
     class FirstInputTensorNet(nn.Cell):
         def __init__(self):
             super(FirstInputTensorNet, self).__init__()
 
-        def construct(self, tensor_x, tuple_a, list_b, tensor_y, scalar, dict_c, flag):
-            if flag:
-                return tensor_x - tuple_a[2] + list_b[1][1]["x"] - tensor_y + scalar - dict_c["x"]
-            return tensor_x + tuple_a[2] - list_b[1][1]["y"] + tensor_y - scalar + dict_c["y"]
+        def construct(self, tensor_x, tuple_a, list_b, tensor_y, tensor_z, dict_c):
+            return tensor_x + tuple_a[2] - list_b[1][1]["y"] + tensor_y - tensor_z + dict_c["y"]
 
-    grad_fist_input_tensor_net = GradNet(FirstInputTensorNet(), get_all=False)
-    ret = grad_fist_input_tensor_net(z, arg_t0, arg_l0, w, sl, args_d0, flag_0)
+    grad_first_input_tensor_net = GradNet1(FirstInputTensorNet(), get_all=False)
+    ret = grad_first_input_tensor_net(z, arg_t0, arg_l0, w, y, args_d0)
     assert np.allclose(ret.asnumpy(), np.ones((2, 2), np.float32))
 
-    grad_fist_input_tuple_net = GradNet(forward_net, get_all=False)
-    assert not grad_fist_input_tuple_net(arg_t0, z, arg_l0, w, sl, args_d0, flag_0)
-
 
 def test_net_inputs_including_str():
     with pytest.raises(TypeError) as err:
diff --git a/tests/ut/python/pynative_mode/test_framstruct.py b/tests/ut/python/pynative_mode/test_framstruct.py
index 781dc306980..5140960451c 100644
--- a/tests/ut/python/pynative_mode/test_framstruct.py
+++ b/tests/ut/python/pynative_mode/test_framstruct.py
@@ -14,7 +14,7 @@
 # ============================================================================
 """ test_framstruct """
 import numpy as np
-
+import pytest
 import mindspore as ms
 import mindspore.nn as nn
 from mindspore import context
@@ -76,11 +76,13 @@ def dynamic_make_tuple(x, lower, upper):
 
 
 def test_dynamic_make_tuple():
-    assert dynamic_make_tuple(2, 1, 5) == (2, 2, 2, 2)
+    # Dynamically and recursively creating a static type is invalid in MindSpore, as it is statically typed.
+    with pytest.raises(RuntimeError):
+        dynamic_make_tuple(2, 1, 5)
 
 
 def test_make_tuple():
-    # Staticly recursively creating static type is valid in mindspore.
+    # Statically and recursively creating a static type is valid in MindSpore.
     @ms_function
     def make_tuple(x):
         out = ()