remove usage of obj in ms_function
This commit is contained in:
parent
dbfa01dfea
commit
0d3f2c6f3c
|
@ -1,7 +1,7 @@
|
||||||
mindspore.ms_function
|
mindspore.ms_function
|
||||||
=====================
|
=====================
|
||||||
|
|
||||||
.. py:function:: mindspore.ms_function(fn=None, obj=None, input_signature=None, hash_args=None)
|
.. py:function:: mindspore.ms_function(fn=None, input_signature=None, hash_args=None)
|
||||||
|
|
||||||
将Python函数编译为一张可调用的MindSpore图。
|
将Python函数编译为一张可调用的MindSpore图。
|
||||||
|
|
||||||
|
@ -10,7 +10,6 @@ mindspore.ms_function
|
||||||
**参数:**
|
**参数:**
|
||||||
|
|
||||||
- **fn** (Function) - 要编译成图的Python函数。默认值:None。
|
- **fn** (Function) - 要编译成图的Python函数。默认值:None。
|
||||||
- **obj** (Object) - 用于区分编译后函数的Python对象。默认值:None。
|
|
||||||
- **input_signature** (Tensor) - 用于表示输入参数的Tensor。Tensor的shape和dtype将作为函数的输入shape和dtype。默认值:None。
|
- **input_signature** (Tensor) - 用于表示输入参数的Tensor。Tensor的shape和dtype将作为函数的输入shape和dtype。默认值:None。
|
||||||
- **hash_args** (Union[Object, List or Tuple of Objects]) - `fn` 里面用到的自由变量,比如外部函数或类对象,再次调用时若 `hash_args` 出现变化会触发重新编译。默认值:None。
|
- **hash_args** (Union[Object, List or Tuple of Objects]) - `fn` 里面用到的自由变量,比如外部函数或类对象,再次调用时若 `hash_args` 出现变化会触发重新编译。默认值:None。
|
||||||
|
|
||||||
|
|
|
@ -287,6 +287,8 @@ class _MindsporeFunctionExecutor:
|
||||||
str(self.fn.__code__.co_firstlineno) + '.' + str(id(self.fn))
|
str(self.fn.__code__.co_firstlineno) + '.' + str(id(self.fn))
|
||||||
if _pynative_executor.grad_flag():
|
if _pynative_executor.grad_flag():
|
||||||
generate_name = generate_name + ".grad"
|
generate_name = generate_name + ".grad"
|
||||||
|
if is_pynative_parallel():
|
||||||
|
generate_name = generate_name[:generate_name.rfind(str(id(self.fn)))] + str(id(self.shard_parent_obj))
|
||||||
self.fn.__parse_method__ = method_name
|
self.fn.__parse_method__ = method_name
|
||||||
|
|
||||||
# Add key with obj
|
# Add key with obj
|
||||||
|
@ -427,7 +429,7 @@ def _get_ms_function_hash(hash_input):
|
||||||
return _get_obj_id(hash_input)
|
return _get_obj_id(hash_input)
|
||||||
|
|
||||||
|
|
||||||
def ms_function(fn=None, obj=None, input_signature=None, hash_args=None):
|
def ms_function(fn=None, input_signature=None, hash_args=None):
|
||||||
"""
|
"""
|
||||||
Create a callable MindSpore graph from a Python function.
|
Create a callable MindSpore graph from a Python function.
|
||||||
|
|
||||||
|
@ -435,7 +437,6 @@ def ms_function(fn=None, obj=None, input_signature=None, hash_args=None):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fn (Function): The Python function that will be run as a graph. Default: None.
|
fn (Function): The Python function that will be run as a graph. Default: None.
|
||||||
obj (Object): The Python object is used to distinguish the compiled function. Default: None.
|
|
||||||
input_signature (Tensor): The Tensor which describes the input arguments. The shape and dtype of the Tensor
|
input_signature (Tensor): The Tensor which describes the input arguments. The shape and dtype of the Tensor
|
||||||
will be supplied to this function. If input_signature is specified, each input to `fn` must be a `Tensor`.
|
will be supplied to this function. If input_signature is specified, each input to `fn` must be a `Tensor`.
|
||||||
And the input parameters of `fn` cannot accept `**kwargs`. The shape and dtype of actual inputs should
|
And the input parameters of `fn` cannot accept `**kwargs`. The shape and dtype of actual inputs should
|
||||||
|
@ -509,14 +510,12 @@ def ms_function(fn=None, obj=None, input_signature=None, hash_args=None):
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def staging_specialize(*args):
|
def staging_specialize(*args):
|
||||||
if obj is not None:
|
|
||||||
logger.warning("Obj is no longer in use, and the function's own object has been used to \
|
|
||||||
distinguish whether it has been compiled.")
|
|
||||||
process_obj = None
|
process_obj = None
|
||||||
if args and not isinstance(args[0], MsTensor) and hasattr(args[0], func.__name__):
|
if args and not isinstance(args[0], MsTensor) and hasattr(args[0], func.__name__):
|
||||||
process_obj = args[0]
|
process_obj = args[0]
|
||||||
if process_obj is None and is_pynative_parallel():
|
if is_pynative_parallel():
|
||||||
process_obj = obj
|
process_obj = args[0]
|
||||||
|
args = args[1:]
|
||||||
out = _MindsporeFunctionExecutor(func, hash_obj, input_signature, process_obj)(*args)
|
out = _MindsporeFunctionExecutor(func, hash_obj, input_signature, process_obj)(*args)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
|
@ -837,11 +837,14 @@ class Shard(Shard_):
|
||||||
return self.shard_fn
|
return self.shard_fn
|
||||||
shard_ = Shard()
|
shard_ = Shard()
|
||||||
|
|
||||||
@ms_function(obj=fn)
|
def shard_fn(*args):
|
||||||
def after_shard(*args):
|
args = (fn,) + args
|
||||||
return shard_(fn, in_strategy, out_strategy, device, level)(*args)
|
@ms_function(hash_args=fn)
|
||||||
|
def after_shard(*args):
|
||||||
|
return shard_(fn, in_strategy, out_strategy, device, level)(*args)
|
||||||
|
return after_shard(*args)
|
||||||
|
|
||||||
self.shard_fn = after_shard
|
self.shard_fn = shard_fn
|
||||||
self.fn = fn
|
self.fn = fn
|
||||||
self.in_strategy = in_strategy
|
self.in_strategy = in_strategy
|
||||||
self.out_strategy = out_strategy
|
self.out_strategy = out_strategy
|
||||||
|
|
Loading…
Reference in New Issue