diff --git a/.jenkins/check/config/filter_pylint.txt b/.jenkins/check/config/filter_pylint.txt index 1b1e8268c3c..f277cc0ee80 100644 --- a/.jenkins/check/config/filter_pylint.txt +++ b/.jenkins/check/config/filter_pylint.txt @@ -25,7 +25,7 @@ "mindspore/mindspore/python/mindspore/ops/_op_impl/_custom_op" "dangerous-default-value "mindspore/mindspore/python/mindspore/ops/_op_impl/_custom_op" "simplifiable-if-expression" "mindspore/mindspore/python/mindspore/ops/_op_impl/_custom_op" "unused-variable" -"mindspore/mindspore/python/mindspore/ops/composite/base.py" "protected-acces" +"mindspore/mindspore/python/mindspore/ops/composite/base.py" "protected-access" "mindspore/mindspore/python/mindspore/ops/primitive.py" "assignment-from-none" "mindspore/mindspore/python/mindspore/nn/cell.py" "assignment-from-none" "mindspore/mindspore/python/mindspore/_extends/parse/resources.py" "bad-whitespace" diff --git a/mindspore/ccsrc/frontend/optimizer/ad/pynative_dfunctor.cc b/mindspore/ccsrc/frontend/optimizer/ad/pynative_dfunctor.cc index 0885ec83e00..8840a3c776d 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/pynative_dfunctor.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/pynative_dfunctor.cc @@ -61,12 +61,21 @@ ValueNodePtr PynativeDFunctor::GenNewTensor(const CNodePtr &cnode_morph) { if (output_values.empty()) { MS_LOG(EXCEPTION) << "The output values is empty, cnode morph: " << cnode_morph->DebugString(); } - return NewValueNode(std::make_shared(output_values)); + auto value_tuple = std::make_shared(output_values); + auto value_node = NewValueNode(value_tuple); + value_node->set_abstract(value_tuple->ToAbstract()->Broaden()); + return value_node; } else if (cnode_type->isa()) { - return NewValueNode(GenNewTensorInner(cnode_type, cnode_shape)); + auto tensor_value = GenNewTensorInner(cnode_type, cnode_shape); + auto value_node = NewValueNode(tensor_value); + value_node->set_abstract(tensor_value->ToAbstract()->Broaden()); + return value_node; } else if (cnode_shape->isa()) 
{ ShapeVector NoShape; - return NewValueNode(std::make_shared(cnode_type->type_id(), NoShape)); + auto tensor_value = std::make_shared(cnode_type->type_id(), NoShape); + auto value_node = NewValueNode(tensor_value); + value_node->set_abstract(tensor_value->ToAbstract()->Broaden()); + return value_node; } MS_LOG(EXCEPTION) << "Unknown shape: " << cnode_shape->ToString() << ", type: " << cnode_type->ToString(); diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 0be7e3db32c..8a8ed5cf77a 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -2072,7 +2072,7 @@ void GradExecutor::SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr MS_EXCEPTION_IF_NULL(value_node); TensorValueToTensor(value_node->value(), &tensors_in_bprop_graph); } - + // Check exception case. auto &tensor_id_with_tensor_object = top_cell()->tensor_id_with_tensor_object(); if (!tensor_id_with_tensor_object.empty()) { MS_LOG(EXCEPTION) << "When compile a top graph, the tensor_id_with_tensor_object map should be empty. 
Top cell: " diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_cpu_kernel.cc index df0de37ae47..16afdc4daeb 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_cpu_kernel.cc @@ -97,12 +97,17 @@ class ArithmeticCpuTypeFunc : public CpuKernelFunc { } size_t l = input_shape1_.size(); - for (size_t i = 0; i < output_shape_.size() - l; ++i) { - (void)input_shape1_.insert(input_shape1_.begin(), 1); + if (l < output_shape_.size()) { + for (size_t i = 0; i < output_shape_.size() - l; ++i) { + (void)input_shape1_.insert(input_shape1_.begin(), 1); + } } + l = input_shape2_.size(); - for (size_t i = 0; i < output_shape_.size() - l; ++i) { - (void)input_shape2_.insert(input_shape2_.begin(), 1); + if (l < output_shape_.size()) { + for (size_t i = 0; i < output_shape_.size() - l; ++i) { + (void)input_shape2_.insert(input_shape2_.begin(), 1); + } } CPUKernelUtils::GetElementNumEveryDim(input_shape1_, &input_element_num1_); CPUKernelUtils::GetElementNumEveryDim(input_shape2_, &input_element_num2_); diff --git a/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_device_context.cc b/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_device_context.cc index 67fb5f93824..fcfa05328de 100644 --- a/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_device_context.cc +++ b/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_device_context.cc @@ -422,13 +422,12 @@ void GPUDeviceContext::UpdateDynamicShape(const CNodePtr &kernel) const { return; } - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - if (session::AnfRuntimeAlgorithm::GetKernelType(kernel) == KernelType::AKG_KERNEL) { MS_LOG(EXCEPTION) << "Akg kernel do not support dynamic shape: " << kernel->fullname_with_scope(); } + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); 
kernel::NativeGpuKernelMod *gpu_kernel = dynamic_cast(kernel_mod); MS_EXCEPTION_IF_NULL(gpu_kernel); diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/actor_common.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/actor_common.cc index 9cd15dea40c..8c6302b4141 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/actor_common.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/actor_common.cc @@ -169,14 +169,14 @@ bool Copy(const DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_ // Exist the size alignment in some device, so get the min device size. size_t copy_size = std::min(src_device_tensor->GetSize(), dst_device_tensor->GetSize()); - if (src_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) { + if (dst_device_tensor->DeviceType() == src_device_tensor->DeviceType()) { + return dst_device_tensor->SyncDeviceToDevice(src_device_tensor); + } else if (src_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) { // CPU device tensor copy to other device tensor. return dst_device_tensor->SyncHostToDevice(copy_size, src_device_tensor->GetPtr()); } else if (dst_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) { // Other device tensor copy to CPU device tensor. 
return src_device_tensor->SyncDeviceToHost(copy_size, dst_device_tensor->GetMutablePtr()); - } else if (dst_device_tensor->DeviceType() == src_device_tensor->DeviceType()) { - return dst_device_tensor->SyncDeviceToDevice(src_device_tensor); } else { MS_LOG(ERROR) << "Invalid device type, src device type: " << src_device_tensor->DeviceType() << ", dst device type: " << dst_device_tensor->DeviceType(); diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/data_prepare_actor.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/data_prepare_actor.cc index a928d94688c..f4759ba3909 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/data_prepare_actor.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/data_prepare_actor.cc @@ -410,7 +410,7 @@ void DataPrepareActor::PrepareDataForHostTensorQueue(const std::vectorFetchNodePosition(input_node); @@ -423,12 +423,19 @@ void DataPrepareActor::PrepareDataForHostTensorQueue(const std::vector(input_tensor->device_address()); auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0, false); MS_EXCEPTION_IF_NULL(device_address); + // Passthrough input node: graph(..., input_node, ...) -> return(..., input_node, ...) + auto passthrough_inp_node = + graph->GetFrontNodeWithIndexByGraphOutput(std::make_pair(input_node, 0)).first != nullptr; + // In order to avoid the device ptr_ being held by the input tensor and the output tensor, the tensor address + // cannot be directly set to the passthrough input node. The 'ptr_' of passthrough input node is re-malloced and + // filled by a device-to-device copy from the input tensor address. 
if ((tensor_address != nullptr) && (tensor_address->DeviceType() == device_address->DeviceType()) && - !device_address->is_ptr_persisted() && tensor_address->format() == device_address->format()) { + !device_address->is_ptr_persisted() && tensor_address->format() == device_address->format() && + !passthrough_inp_node) { AnfAlgo::SetOutputAddr(tensor_address, 0, input_node.get()); tensor_address->SetNodeIndex(input_node, 0); } - device_address->SetSize(host_tensors[tensor_position]->data().nbytes()); + device_address->SetSize(input_tensor->data().nbytes()); } } diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/output_actor.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/output_actor.cc index 779df6fd0ea..4d5990d9951 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/output_actor.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/output_actor.cc @@ -37,9 +37,9 @@ bool IsOutputAddressPersisted(const DeviceTensor *output_device_tensor, const An return true; } - // In the input as output scenario, the output device tensor may come from the input tensor and can't be replaced. - // But in the dynamic shape scenario, need to free the old memory and alloc new memory using the new shape size. - if (output_node->isa() && !(output_node->cast()->has_dynamic_shape())) { + // In the input as output scenario. + // If the input node is the weight of graph, its output device tensor does not need to be replaced. 
+ if (output_node->isa() && common::AnfAlgo::IsParameterWeight(output_node->cast())) { return true; } diff --git a/mindspore/python/mindspore/_checkparam.py b/mindspore/python/mindspore/_checkparam.py index cd31a46ffe3..c34d457f36c 100644 --- a/mindspore/python/mindspore/_checkparam.py +++ b/mindspore/python/mindspore/_checkparam.py @@ -950,6 +950,33 @@ class Validator: raise TypeError(f"For COOTensor, `indices` must have int16 or int32 or int64 data type, but got " \ f"{indices_dtype}.") + @staticmethod + def check_dynamic_shape(dyn_inputs, actual_inputs): + """Check the consistency of dynamic shape tensors and actual input tensors.""" + dyn_inputs_size = len(dyn_inputs) + actual_inputs_size = len(actual_inputs) + if dyn_inputs_size != actual_inputs_size: + raise ValueError(f"The number of actual input tensors: {actual_inputs_size} is not equal to the number of " + f"dynamic shape tensors: {dyn_inputs_size}.") + for i in range(dyn_inputs_size): + if dyn_inputs[i].dtype is not actual_inputs[i].dtype: + raise TypeError(f"The data type of index `{i}` args in actual input tensors should be " + f"`{dyn_inputs[i].dtype}`, but got `{actual_inputs[i].dtype}`.") + if len(dyn_inputs[i].shape) != len(actual_inputs[i].shape): + raise ValueError(f"The dimension of index `{i}` args in actual input tensors should be " + f"`{len(dyn_inputs[i].shape)}`, but got `{len(actual_inputs[i].shape)}`.") + check_dyn_shape_value_equal(i, dyn_inputs[i].shape, actual_inputs[i].shape) + return True + + +def check_dyn_shape_value_equal(index, dyn_shape, actual_shape): + """Check the consistency of dynamic shape and actual input shape.""" + for i, x in enumerate(dyn_shape): + if x not in (-1, actual_shape[i]): + raise ValueError(f"The {i}th value in shape of index `{index}` args should be `{x}`, but got " + f"`{actual_shape[i]}`.") + return True + def check_input_format(input_param): """Judge input format.""" diff --git a/mindspore/python/mindspore/common/api.py 
b/mindspore/python/mindspore/common/api.py index f8c2542e423..c71814d116a 100644 --- a/mindspore/python/mindspore/common/api.py +++ b/mindspore/python/mindspore/common/api.py @@ -25,7 +25,7 @@ import importlib from collections import OrderedDict from functools import wraps import numpy as np - +import mindspore as ms from mindspore import context from mindspore import log as logger from mindspore._extends.remote import kernel_build_server @@ -249,7 +249,6 @@ class _MindsporeFunctionExecutor: self._graph_executor.updata_param_node_default_input(phase, new_param) obj.load_parameter_slice(None) - if _pynative_executor.get_optimizer(): params = obj.trainable_params() opt_params = _pynative_executor.get_optimizer().trainable_params() @@ -282,17 +281,8 @@ class _MindsporeFunctionExecutor: logger.warning(f"For 'Cell', it's not support hook function when using ms_function. If you want to " f"use hook function, please use context.set_context to set pynative mode and remove " f"`ms_function`.") - # Verify the signature for both function and method - if self.input_signature is not None: - signatures = [] - for sig_spec in self.input_signature: - if not isinstance(sig_spec, MetaTensor): - raise TypeError("Input_signature is not MetaTensor") - signatures.append(sig_spec) - is_valid_input = verify_inputs_signature(signatures, args_list) - if not is_valid_input: - raise ValueError("Inputs is incompatible with input signature!") - + # Choose dynamic shape tensors or actual input tensors as compile args. + compile_args = self._generate_compile_args(args_list) generate_name = self.fn.__module__ + "." + self.fn.__name__ + "." + self.fn.__code__.co_filename + "." + \ str(self.fn.__code__.co_firstlineno) + '.' 
+ str(id(self.fn)) if _pynative_executor.grad_flag(): @@ -314,7 +304,7 @@ class _MindsporeFunctionExecutor: self.enable_tuple_broaden = self.obj.enable_tuple_broaden self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden) - key = self._graph_executor.generate_arguments_key(args_list, self.enable_tuple_broaden) + key = self._graph_executor.generate_arguments_key(compile_args, self.enable_tuple_broaden) phase = generate_name + '.' + str(key) if phase in ms_compile_cache: return phase @@ -323,10 +313,10 @@ class _MindsporeFunctionExecutor: self._set_compile_cache_dep_files() if self.obj is None: - is_compile = self._graph_executor.compile(self.fn, args_list, phase, True) + is_compile = self._graph_executor.compile(self.fn, compile_args, phase, True) else: self._graph_executor.set_weights_values(self.obj.parameters_dict()) - is_compile = self._graph_executor.compile(self.obj, args_list, phase, True) + is_compile = self._graph_executor.compile(self.obj, compile_args, phase, True) if is_pynative_parallel(): self._parallel_process_for_ms_function(phase) @@ -374,6 +364,43 @@ class _MindsporeFunctionExecutor: return output + def _generate_compile_args(self, args_list): + """Choose dynamic shape tensors or actual input tensors as compile args.""" + compile_args = args_list + # Case: The `set_inputs()` of Cell object has been set, using these dynamic shape args as compile args. + if isinstance(self.obj, ms.nn.Cell) and self.obj.get_inputs(): + compile_args = self.obj.get_inputs() + for args in compile_args: + if not isinstance(args, MsTensor): + raise TypeError(f"The args in `set_inputs()` of Cell object should be a Tensor, " + f"but got {type(args)}.") + Validator.check_dynamic_shape(compile_args, args_list) + # Case: The dynamic shape tensors have been assigned to `input_signature`, they are preferred as compile args. 
+ if self.input_signature is not None: + if not isinstance(self.input_signature, (tuple, list)): + self.input_signature = (self.input_signature,) + self.input_signature = list(self.input_signature) + dyn_shape = False + for sig_args in self.input_signature: + if not isinstance(sig_args, (MetaTensor, MsTensor)): + raise TypeError(f"The args in `input_signature` of `ms_function` should be a Tensor, " + f"but got {type(sig_args)}.") + if -1 in sig_args.shape: + dyn_shape = True + if not dyn_shape: + if not verify_inputs_signature(self.input_signature, args_list): + raise ValueError("The input args is incompatible with the args in `input_signature`!") + else: + # Check whether the `sens` has been added to args_list. + if len(self.input_signature) == len(args_list) - 1: + logger.warning(f"The number of actual input args `{len(args_list)}` is one more than the number " + f"of dynamic shape args `{len(self.input_signature)}`. The last actual args may be " + f" `sens` and added to compile args.") + self.input_signature.append(args_list[-1]) + Validator.check_dynamic_shape(self.input_signature, args_list) + compile_args = self.input_signature + return tuple(compile_args) + def ms_function(fn=None, obj=None, input_signature=None): """ @@ -932,9 +959,6 @@ class _CellGraphExecutor: self._graph_executor.set_queue_name(queue_name) return True - def _build_data_graph(self, obj, phase): - self._graph_executor.build_data_graph(obj.parameters_dict(), phase, obj.parameters_broadcast_dict()) - def set_queue_name(self, queue_name): """ while a mode use shared dataset with others, need set queue_name which saved in data_set @@ -952,6 +976,9 @@ class _CellGraphExecutor: else: _set_dataset_mode_config('normal') + def _build_data_graph(self, obj, phase): + self._graph_executor.build_data_graph(obj.parameters_dict(), phase, obj.parameters_broadcast_dict()) + @staticmethod def _use_vm_mode(): enable_ge = context.get_context("enable_ge") diff --git 
a/mindspore/python/mindspore/ops/composite/base.py b/mindspore/python/mindspore/ops/composite/base.py index 495323ede99..2e9f094c275 100644 --- a/mindspore/python/mindspore/ops/composite/base.py +++ b/mindspore/python/mindspore/ops/composite/base.py @@ -18,7 +18,7 @@ """Basic composite operations.""" from functools import partial from types import FunctionType - +import mindspore as ms from mindspore import context from ..._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, Shard_, \ TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_, \ @@ -368,12 +368,15 @@ class GradOperation(GradOperation_): # In pure PYNATIVE_MODE the out layer after_grad just used to set pynative flag for inner GradOperation. # In PYNATIVE_MODE calling Grad from ms_function, use the out layer after_grad do grad in GRAPH_MODE. if context.get_context("mode") == context.GRAPH_MODE: + dynamic_shape_inputs = None + if isinstance(fn, ms.nn.Cell): + dynamic_shape_inputs = fn.get_inputs() if self.get_by_list: - @ms_function + @ms_function(input_signature=dynamic_shape_inputs) def after_grad(*args): return grad_(fn, weights)(*args) else: - @ms_function + @ms_function(input_signature=dynamic_shape_inputs) def after_grad(*args): return grad_(fn)(*args) elif self.pynative_: @@ -461,17 +464,20 @@ class _Grad(GradOperation_): # In pure PYNATIVE_MODE the out layer after_grad just used to set pynative flag for inner GradOperation. # In PYNATIVE_MODE calling Grad from ms_function, use the out layer after_grad do grad in GRAPH_MODE. 
if context.get_context("mode") == context.GRAPH_MODE: + dynamic_shape_inputs = None + if isinstance(fn, ms.nn.Cell): + dynamic_shape_inputs = fn.get_inputs() if self.get_by_position: - @ms_function + @ms_function(input_signature=dynamic_shape_inputs) def after_grad(*args): return grad_(fn, weights, grad_position)(*args) else: if self.get_by_list: - @ms_function + @ms_function(input_signature=dynamic_shape_inputs) def after_grad(*args): return grad_(fn, weights)(*args) else: - @ms_function + @ms_function(input_signature=dynamic_shape_inputs) def after_grad(*args): return grad_(fn)(*args) elif self.pynative_: