forked from mindspore-Ecosystem/mindspore
enable_dynamic_shape_in_ms_function
parent 430e8b7d53
commit 421d0cf339
@@ -25,7 +25,7 @@
"mindspore/mindspore/python/mindspore/ops/_op_impl/_custom_op" "dangerous-default-value"
"mindspore/mindspore/python/mindspore/ops/_op_impl/_custom_op" "simplifiable-if-expression"
"mindspore/mindspore/python/mindspore/ops/_op_impl/_custom_op" "unused-variable"
"mindspore/mindspore/python/mindspore/ops/composite/base.py" "protected-acces"
"mindspore/mindspore/python/mindspore/ops/composite/base.py" "protected-access"
"mindspore/mindspore/python/mindspore/ops/primitive.py" "assignment-from-none"
"mindspore/mindspore/python/mindspore/nn/cell.py" "assignment-from-none"
"mindspore/mindspore/python/mindspore/_extends/parse/resources.py" "bad-whitespace"
@@ -61,12 +61,21 @@ ValueNodePtr PynativeDFunctor::GenNewTensor(const CNodePtr &cnode_morph) {
    if (output_values.empty()) {
      MS_LOG(EXCEPTION) << "The output values is empty, cnode morph: " << cnode_morph->DebugString();
    }
    return NewValueNode(std::make_shared<ValueTuple>(output_values));
    auto value_tuple = std::make_shared<ValueTuple>(output_values);
    auto value_node = NewValueNode(value_tuple);
    value_node->set_abstract(value_tuple->ToAbstract()->Broaden());
    return value_node;
  } else if (cnode_type->isa<TensorType>()) {
    return NewValueNode(GenNewTensorInner(cnode_type, cnode_shape));
    auto tensor_value = GenNewTensorInner(cnode_type, cnode_shape);
    auto value_node = NewValueNode(tensor_value);
    value_node->set_abstract(tensor_value->ToAbstract()->Broaden());
    return value_node;
  } else if (cnode_shape->isa<abstract::NoShape>()) {
    ShapeVector NoShape;
    return NewValueNode(std::make_shared<tensor::Tensor>(cnode_type->type_id(), NoShape));
    auto tensor_value = std::make_shared<tensor::Tensor>(cnode_type->type_id(), NoShape);
    auto value_node = NewValueNode(tensor_value);
    value_node->set_abstract(tensor_value->ToAbstract()->Broaden());
    return value_node;
  }

  MS_LOG(EXCEPTION) << "Unknown shape: " << cnode_shape->ToString() << ", type: " << cnode_type->ToString();
@@ -2072,7 +2072,7 @@ void GradExecutor::SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr
    MS_EXCEPTION_IF_NULL(value_node);
    TensorValueToTensor(value_node->value(), &tensors_in_bprop_graph);
  }

  // Check exception case.
  auto &tensor_id_with_tensor_object = top_cell()->tensor_id_with_tensor_object();
  if (!tensor_id_with_tensor_object.empty()) {
    MS_LOG(EXCEPTION) << "When compile a top graph, the tensor_id_with_tensor_object map should be empty. Top cell: "
@@ -97,12 +97,17 @@ class ArithmeticCpuTypeFunc : public CpuKernelFunc {
    }

    size_t l = input_shape1_.size();
    for (size_t i = 0; i < output_shape_.size() - l; ++i) {
      (void)input_shape1_.insert(input_shape1_.begin(), 1);
    if (l < output_shape_.size()) {
      for (size_t i = 0; i < output_shape_.size() - l; ++i) {
        (void)input_shape1_.insert(input_shape1_.begin(), 1);
      }
    }

    l = input_shape2_.size();
    for (size_t i = 0; i < output_shape_.size() - l; ++i) {
      (void)input_shape2_.insert(input_shape2_.begin(), 1);
    if (l < output_shape_.size()) {
      for (size_t i = 0; i < output_shape_.size() - l; ++i) {
        (void)input_shape2_.insert(input_shape2_.begin(), 1);
      }
    }
    CPUKernelUtils::GetElementNumEveryDim(input_shape1_, &input_element_num1_);
    CPUKernelUtils::GetElementNumEveryDim(input_shape2_, &input_element_num2_);
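Not part of the diff: a minimal Python sketch of the broadcast rank alignment this hunk guards. The added condition mirrors `if (l < output_shape_.size())` above, which also prevents the unsigned expression `output_shape_.size() - l` from underflowing when an input already has at least as many dimensions as the output (possible with dynamic shapes). The function and names below are illustrative, not MindSpore APIs.

```python
def align_rank(input_shape, output_shape):
    """Prepend 1s so input_shape has the same rank as output_shape (illustration only)."""
    input_shape = list(input_shape)
    # Only pad when the input rank is smaller; this is the condition the C++ fix adds
    # so the loop bound is never computed from an underflowed unsigned difference.
    if len(input_shape) < len(output_shape):
        input_shape = [1] * (len(output_shape) - len(input_shape)) + input_shape
    return input_shape

assert align_rank([3], [2, 3]) == [1, 3]            # rank raised for broadcasting
assert align_rank([4, 2, 3], [2, 3]) == [4, 2, 3]   # larger rank left untouched
```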
@@ -422,13 +422,12 @@ void GPUDeviceContext::UpdateDynamicShape(const CNodePtr &kernel) const {
    return;
  }

  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);

  if (session::AnfRuntimeAlgorithm::GetKernelType(kernel) == KernelType::AKG_KERNEL) {
    MS_LOG(EXCEPTION) << "Akg kernel do not support dynamic shape: " << kernel->fullname_with_scope();
  }

  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  kernel::NativeGpuKernelMod *gpu_kernel = dynamic_cast<kernel::NativeGpuKernelMod *>(kernel_mod);
  MS_EXCEPTION_IF_NULL(gpu_kernel);
@@ -169,14 +169,14 @@ bool Copy(const DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_
  // Exist the size alignment in some device, so get the min device size.
  size_t copy_size = std::min(src_device_tensor->GetSize(), dst_device_tensor->GetSize());

  if (src_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) {
  if (dst_device_tensor->DeviceType() == src_device_tensor->DeviceType()) {
    return dst_device_tensor->SyncDeviceToDevice(src_device_tensor);
  } else if (src_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) {
    // CPU device tensor copy to other device tensor.
    return dst_device_tensor->SyncHostToDevice(copy_size, src_device_tensor->GetPtr());
  } else if (dst_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) {
    // Other device tensor copy to CPU device tensor.
    return src_device_tensor->SyncDeviceToHost(copy_size, dst_device_tensor->GetMutablePtr());
  } else if (dst_device_tensor->DeviceType() == src_device_tensor->DeviceType()) {
    return dst_device_tensor->SyncDeviceToDevice(src_device_tensor);
  } else {
    MS_LOG(ERROR) << "Invalid device type, src device type: " << src_device_tensor->DeviceType()
                  << ", dst device type: " << dst_device_tensor->DeviceType();
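Not part of the diff: a hedged Python paraphrase of the copy dispatch order after this change, whose point is that a same-device copy is now handled before the CPU-specific branches. The class and method names below are illustrative stand-ins, not the MindSpore C++ API.

```python
from dataclasses import dataclass

@dataclass
class FakeDeviceTensor:
    """Illustrative stand-in for the C++ DeviceTensor; not a MindSpore class."""
    device_type: str

def copy_order(dst: FakeDeviceTensor, src: FakeDeviceTensor) -> str:
    """Return which sync path the reordered branches would take."""
    if dst.device_type == src.device_type:
        return "SyncDeviceToDevice"   # same device handled first after this change
    if src.device_type == "CPU":
        return "SyncHostToDevice"     # CPU tensor copied to another device
    if dst.device_type == "CPU":
        return "SyncDeviceToHost"     # device tensor copied back to CPU
    raise RuntimeError(f"Invalid device type: {src.device_type} -> {dst.device_type}")

assert copy_order(FakeDeviceTensor("GPU"), FakeDeviceTensor("GPU")) == "SyncDeviceToDevice"
assert copy_order(FakeDeviceTensor("GPU"), FakeDeviceTensor("CPU")) == "SyncHostToDevice"
```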
@@ -410,7 +410,7 @@ void DataPrepareActor::PrepareDataForHostTensorQueue(const std::vector<std::vect
        input_tensor == nullptr) {
      continue;
    }

    // Synchronize dynamic shape info of the input tensor to the parameter node of graph.
    UpdateDynamicShape(input_node, input_tensor);

    auto tensor_position = host_data_source_actor_->FetchNodePosition(input_node);
@@ -423,12 +423,19 @@ void DataPrepareActor::PrepareDataForHostTensorQueue(const std::vector<std::vect
    auto tensor_address = std::dynamic_pointer_cast<DeviceTensor>(input_tensor->device_address());
    auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
    MS_EXCEPTION_IF_NULL(device_address);
    // Passthrough input node: graph(..., input_node, ...) -> return(..., input_node, ...)
    auto passthrough_inp_node =
      graph->GetFrontNodeWithIndexByGraphOutput(std::make_pair(input_node, 0)).first != nullptr;
    // In order to avoid the device ptr_ being hold by the input tensor and the output tensor, the tensor address
    // cannot be directly set to the passthrough input node. The 'ptr_' of passthrough input node is re-malloced and
    // device to device copy by input tensor address.
    if ((tensor_address != nullptr) && (tensor_address->DeviceType() == device_address->DeviceType()) &&
        !device_address->is_ptr_persisted() && tensor_address->format() == device_address->format()) {
        !device_address->is_ptr_persisted() && tensor_address->format() == device_address->format() &&
        !passthrough_inp_node) {
      AnfAlgo::SetOutputAddr(tensor_address, 0, input_node.get());
      tensor_address->SetNodeIndex(input_node, 0);
    }
    device_address->SetSize(host_tensors[tensor_position]->data().nbytes());
    device_address->SetSize(input_tensor->data().nbytes());
  }
}
@@ -37,9 +37,9 @@ bool IsOutputAddressPersisted(const DeviceTensor *output_device_tensor, const An
    return true;
  }

  // In the input as output scenario, the output device tensor may come from the input tensor and can't be replaced.
  // But in the dynamic shape scenario, need to free the old memory and alloc new memory using the new shape size.
  if (output_node->isa<Parameter>() && !(output_node->cast<ParameterPtr>()->has_dynamic_shape())) {
  // In the input as output scenario.
  // If the input node is the weight of graph, its output device tensor does not need to be replaced.
  if (output_node->isa<Parameter>() && common::AnfAlgo::IsParameterWeight(output_node->cast<ParameterPtr>())) {
    return true;
  }
@@ -950,6 +950,33 @@ class Validator:
        raise TypeError(f"For COOTensor, `indices` must have int16 or int32 or int64 data type, but got " \
                        f"{indices_dtype}.")

    @staticmethod
    def check_dynamic_shape(dyn_inputs, actual_inputs):
        """Check the consistency of dynamic shape tensors and actual input tensors."""
        dyn_inputs_size = len(dyn_inputs)
        actual_inputs_size = len(actual_inputs)
        if dyn_inputs_size != actual_inputs_size:
            raise ValueError(f"The number of actual input tensors: {actual_inputs_size} is not equal to the number of "
                             f"dynamic shape tensors: {dyn_inputs_size}.")
        for i in range(dyn_inputs_size):
            if dyn_inputs[i].dtype is not actual_inputs[i].dtype:
                raise TypeError(f"The data type of index `{i}` args in actual input tensors should be "
                                f"`{dyn_inputs[i].dtype}`, but got `{actual_inputs[i].dtype}`.")
            if len(dyn_inputs[i].shape) != len(actual_inputs[i].shape):
                raise ValueError(f"The dimension of index `{i}` args in actual input tensors should be "
                                 f"`{len(dyn_inputs[i].shape)}`, but got `{len(actual_inputs[i].shape)}`.")
            check_dyn_shape_value_equal(i, dyn_inputs[i].shape, actual_inputs[i].shape)
        return True


def check_dyn_shape_value_equal(index, dyn_shape, actual_shape):
    """Check the consistency of dynamic shape and actual input shape."""
    for i, x in enumerate(dyn_shape):
        if x not in (-1, actual_shape[i]):
            raise ValueError(f"The {i}th value in shape of index `{index}` args should be `{x}`, but got "
                             f"`{actual_shape[i]}`.")
    return True


def check_input_format(input_param):
    """Judge input format."""
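Not part of the diff: a small self-contained illustration of the matching rule that `check_dyn_shape_value_equal` enforces above: a `-1` in a dynamic shape matches any actual size, while every other dimension (and the rank) must match exactly. The helper name below is hypothetical.

```python
def shapes_compatible(dyn_shape, actual_shape):
    """Hypothetical helper mirroring check_dyn_shape_value_equal (illustration only)."""
    if len(dyn_shape) != len(actual_shape):
        return False
    return all(d in (-1, a) for d, a in zip(dyn_shape, actual_shape))

assert shapes_compatible((-1, 3), (32, 3))         # -1 matches any batch size
assert not shapes_compatible((-1, 3), (32, 4))     # fixed dimensions must match exactly
assert not shapes_compatible((-1, 3), (32, 3, 1))  # rank must match as well
```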
@@ -25,7 +25,7 @@ import importlib
from collections import OrderedDict
from functools import wraps
import numpy as np

import mindspore as ms
from mindspore import context
from mindspore import log as logger
from mindspore._extends.remote import kernel_build_server
@@ -249,7 +249,6 @@ class _MindsporeFunctionExecutor:
                self._graph_executor.updata_param_node_default_input(phase, new_param)
            obj.load_parameter_slice(None)


        if _pynative_executor.get_optimizer():
            params = obj.trainable_params()
            opt_params = _pynative_executor.get_optimizer().trainable_params()
@@ -282,17 +281,8 @@ class _MindsporeFunctionExecutor:
            logger.warning(f"For 'Cell', it's not support hook function when using ms_function. If you want to "
                           f"use hook function, please use context.set_context to set pynative mode and remove "
                           f"`ms_function`.")
        # Verify the signature for both function and method
        if self.input_signature is not None:
            signatures = []
            for sig_spec in self.input_signature:
                if not isinstance(sig_spec, MetaTensor):
                    raise TypeError("Input_signature is not MetaTensor")
                signatures.append(sig_spec)
            is_valid_input = verify_inputs_signature(signatures, args_list)
            if not is_valid_input:
                raise ValueError("Inputs is incompatible with input signature!")

        # Chose dynamic shape tensors or actual input tensors as compile args.
        compile_args = self._generate_compile_args(args_list)
        generate_name = self.fn.__module__ + "." + self.fn.__name__ + "." + self.fn.__code__.co_filename + "." + \
            str(self.fn.__code__.co_firstlineno) + '.' + str(id(self.fn))
        if _pynative_executor.grad_flag():
@@ -314,7 +304,7 @@ class _MindsporeFunctionExecutor:
            self.enable_tuple_broaden = self.obj.enable_tuple_broaden

        self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
        key = self._graph_executor.generate_arguments_key(args_list, self.enable_tuple_broaden)
        key = self._graph_executor.generate_arguments_key(compile_args, self.enable_tuple_broaden)
        phase = generate_name + '.' + str(key)
        if phase in ms_compile_cache:
            return phase
@@ -323,10 +313,10 @@ class _MindsporeFunctionExecutor:
            self._set_compile_cache_dep_files()

        if self.obj is None:
            is_compile = self._graph_executor.compile(self.fn, args_list, phase, True)
            is_compile = self._graph_executor.compile(self.fn, compile_args, phase, True)
        else:
            self._graph_executor.set_weights_values(self.obj.parameters_dict())
            is_compile = self._graph_executor.compile(self.obj, args_list, phase, True)
            is_compile = self._graph_executor.compile(self.obj, compile_args, phase, True)

        if is_pynative_parallel():
            self._parallel_process_for_ms_function(phase)
@@ -374,6 +364,43 @@ class _MindsporeFunctionExecutor:

        return output

    def _generate_compile_args(self, args_list):
        """Chose dynamic shape tensors or actual input tensors as compile args."""
        compile_args = args_list
        # Case: The `set_inputs()` of Cell object has been set, using these dynamic shape args as compile args.
        if isinstance(self.obj, ms.nn.Cell) and self.obj.get_inputs():
            compile_args = self.obj.get_inputs()
            for args in compile_args:
                if not isinstance(args, MsTensor):
                    raise TypeError(f"The args in `set_inputs()` of Cell object should be a Tensor, "
                                    f"but got {type(args)}.")
            Validator.check_dynamic_shape(compile_args, args_list)
        # Case: The dynamic shape tensors have been assigned to `input_signature`, they are preferred as compile args.
        if self.input_signature is not None:
            if not isinstance(self.input_signature, (tuple, list)):
                self.input_signature = (self.input_signature,)
            self.input_signature = list(self.input_signature)
            dyn_shape = False
            for sig_args in self.input_signature:
                if not isinstance(sig_args, (MetaTensor, MsTensor)):
                    raise TypeError(f"The args in `input_signature` of `ms_function` should be a Tensor, "
                                    f"but got {type(sig_args)}.")
                if -1 in sig_args.shape:
                    dyn_shape = True
            if not dyn_shape:
                if not verify_inputs_signature(self.input_signature, args_list):
                    raise ValueError("The input args is incompatible with the args in `input_signature`!")
            else:
                # Checkout whether the `sens` has been added to args_list.
                if len(self.input_signature) == len(args_list) - 1:
                    logger.warning(f"The number of actual input args `{len(args_list)}` is one more than the number "
                                   f"of dynamic shape args `{len(self.input_signature)}`. The last actual args may be "
                                   f" `sens` and added to compile args.")
                    self.input_signature.append(args_list[-1])
                Validator.check_dynamic_shape(self.input_signature, args_list)
                compile_args = self.input_signature
        return tuple(compile_args)


def ms_function(fn=None, obj=None, input_signature=None):
    """
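Not part of the diff: a simplified, self-contained paraphrase of the selection order `_generate_compile_args` implements in the hunk above: actual inputs by default, `set_inputs()` placeholders when the Cell registered them, and `input_signature` placeholders when any of their dimensions is `-1` (with a trailing `sens` argument possibly appended). Names below are illustrative; the real method also type-checks and validates via `Validator.check_dynamic_shape`.

```python
from collections import namedtuple

Placeholder = namedtuple("Placeholder", "shape")  # hypothetical stand-in for a Tensor


def choose_compile_args(args_list, cell_inputs=None, input_signature=None):
    """Illustrative decision flow only; the real logic lives in _MindsporeFunctionExecutor."""
    compile_args = list(args_list)
    if cell_inputs:  # placeholders registered via Cell.set_inputs()
        compile_args = list(cell_inputs)
    if input_signature is not None:
        signature = list(input_signature) if isinstance(input_signature, (tuple, list)) \
            else [input_signature]
        if any(-1 in sig.shape for sig in signature):  # dynamic dimensions present
            if len(signature) == len(args_list) - 1:   # trailing actual arg may be `sens`
                signature.append(args_list[-1])
            compile_args = signature
    return tuple(compile_args)


dyn = Placeholder(shape=(-1, 3))
real = Placeholder(shape=(32, 3))
assert choose_compile_args([real], input_signature=[dyn]) == (dyn,)  # dynamic signature wins
assert choose_compile_args([real]) == (real,)                        # otherwise actual inputs
```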
@@ -932,9 +959,6 @@ class _CellGraphExecutor:
        self._graph_executor.set_queue_name(queue_name)
        return True

    def _build_data_graph(self, obj, phase):
        self._graph_executor.build_data_graph(obj.parameters_dict(), phase, obj.parameters_broadcast_dict())

    def set_queue_name(self, queue_name):
        """
        while a mode use shared dataset with others, need set queue_name which saved in data_set
@@ -952,6 +976,9 @@ class _CellGraphExecutor:
        else:
            _set_dataset_mode_config('normal')

    def _build_data_graph(self, obj, phase):
        self._graph_executor.build_data_graph(obj.parameters_dict(), phase, obj.parameters_broadcast_dict())

    @staticmethod
    def _use_vm_mode():
        enable_ge = context.get_context("enable_ge")
@@ -18,7 +18,7 @@
"""Basic composite operations."""
from functools import partial
from types import FunctionType

import mindspore as ms
from mindspore import context
from ..._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, Shard_, \
    TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_, \
@@ -368,12 +368,15 @@ class GradOperation(GradOperation_):
            # In pure PYNATIVE_MODE the out layer after_grad just used to set pynative flag for inner GradOperation.
            # In PYNATIVE_MODE calling Grad from ms_function, use the out layer after_grad do grad in GRAPH_MODE.
            if context.get_context("mode") == context.GRAPH_MODE:
                dynamic_shape_inputs = None
                if isinstance(fn, ms.nn.Cell):
                    dynamic_shape_inputs = fn.get_inputs()
                if self.get_by_list:
                    @ms_function
                    @ms_function(input_signature=dynamic_shape_inputs)
                    def after_grad(*args):
                        return grad_(fn, weights)(*args)
                else:
                    @ms_function
                    @ms_function(input_signature=dynamic_shape_inputs)
                    def after_grad(*args):
                        return grad_(fn)(*args)
            elif self.pynative_:
@@ -461,17 +464,20 @@ class _Grad(GradOperation_):
            # In pure PYNATIVE_MODE the out layer after_grad just used to set pynative flag for inner GradOperation.
            # In PYNATIVE_MODE calling Grad from ms_function, use the out layer after_grad do grad in GRAPH_MODE.
            if context.get_context("mode") == context.GRAPH_MODE:
                dynamic_shape_inputs = None
                if isinstance(fn, ms.nn.Cell):
                    dynamic_shape_inputs = fn.get_inputs()
                if self.get_by_position:
                    @ms_function
                    @ms_function(input_signature=dynamic_shape_inputs)
                    def after_grad(*args):
                        return grad_(fn, weights, grad_position)(*args)
                else:
                    if self.get_by_list:
                        @ms_function
                        @ms_function(input_signature=dynamic_shape_inputs)
                        def after_grad(*args):
                            return grad_(fn, weights)(*args)
                    else:
                        @ms_function
                        @ms_function(input_signature=dynamic_shape_inputs)
                        def after_grad(*args):
                            return grad_(fn)(*args)
            elif self.pynative_:
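Not part of the diff: a hedged end-to-end sketch of what the base.py change enables. When the graph-mode `GradOperation` wraps a Cell that registered dynamic-shape placeholders through `set_inputs()`, those placeholders are now forwarded to the inner `ms_function` as `input_signature`, so the backward graph is compiled with dynamic shapes as well. The placeholder constructor and the exact call pattern below are assumptions based on the public `set_inputs` API and may differ by MindSpore version.

```python
# Hedged usage sketch (not taken from the diff); assumes Tensor(shape=[None, ...], dtype=...)
# builds a shape-only placeholder and Cell.set_inputs()/get_inputs() behave as referenced above.
import numpy as np
import mindspore as ms
from mindspore import Tensor, context, nn, ops

context.set_context(mode=context.GRAPH_MODE)

class Net(nn.Cell):
    def construct(self, x):
        return ops.ReLU()(x)

net = Net()
dyn_x = Tensor(shape=[None, 3], dtype=ms.float32)   # dynamic first dimension
net.set_inputs(dyn_x)                                # GradOperation reads this via fn.get_inputs()

grad_fn = ops.GradOperation()(net)                   # after_grad is now compiled with
print(grad_fn(Tensor(np.ones((8, 3), np.float32))))  # input_signature=net.get_inputs()
```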