remove usage of obj in ms_function

This commit is contained in:
wangjun 2022-05-05 10:25:44 +08:00
parent dbfa01dfea
commit 0d3f2c6f3c
3 changed files with 14 additions and 13 deletions

View File

@ -1,7 +1,7 @@
mindspore.ms_function
=====================
.. py:function:: mindspore.ms_function(fn=None, obj=None, input_signature=None, hash_args=None)
.. py:function:: mindspore.ms_function(fn=None, input_signature=None, hash_args=None)
将Python函数编译为一张可调用的MindSpore图。
@ -10,7 +10,6 @@ mindspore.ms_function
**参数:**
- **fn** (Function) - 要编译成图的Python函数。默认值None。
- **obj** (Object) - 用于区分编译后函数的Python对象。默认值None。
- **input_signature** (Tensor) - 用于表示输入参数的Tensor。Tensor的shape和dtype将作为函数的输入shape和dtype。默认值None。
- **hash_args** (Union[Object, List or Tuple of Objects]) - `fn` 里面用到的自由变量,比如外部函数或类对象,再次调用时若 `hash_args` 出现变化会触发重新编译。默认值None。

View File

@ -287,6 +287,8 @@ class _MindsporeFunctionExecutor:
str(self.fn.__code__.co_firstlineno) + '.' + str(id(self.fn))
if _pynative_executor.grad_flag():
generate_name = generate_name + ".grad"
if is_pynative_parallel():
generate_name = generate_name[:generate_name.rfind(str(id(self.fn)))] + str(id(self.shard_parent_obj))
self.fn.__parse_method__ = method_name
# Add key with obj
@ -427,7 +429,7 @@ def _get_ms_function_hash(hash_input):
return _get_obj_id(hash_input)
def ms_function(fn=None, obj=None, input_signature=None, hash_args=None):
def ms_function(fn=None, input_signature=None, hash_args=None):
"""
Create a callable MindSpore graph from a Python function.
@ -435,7 +437,6 @@ def ms_function(fn=None, obj=None, input_signature=None, hash_args=None):
Args:
fn (Function): The Python function that will be run as a graph. Default: None.
obj (Object): The Python object is used to distinguish the compiled function. Default: None.
input_signature (Tensor): The Tensor which describes the input arguments. The shape and dtype of the Tensor
will be supplied to this function. If input_signature is specified, each input to `fn` must be a `Tensor`.
And the input parameters of `fn` cannot accept `**kwargs`. The shape and dtype of actual inputs should
@ -509,14 +510,12 @@ def ms_function(fn=None, obj=None, input_signature=None, hash_args=None):
@wraps(func)
def staging_specialize(*args):
if obj is not None:
logger.warning("Obj is no longer in use, and the function's own object has been used to \
distinguish whether it has been compiled.")
process_obj = None
if args and not isinstance(args[0], MsTensor) and hasattr(args[0], func.__name__):
process_obj = args[0]
if process_obj is None and is_pynative_parallel():
process_obj = obj
if is_pynative_parallel():
process_obj = args[0]
args = args[1:]
out = _MindsporeFunctionExecutor(func, hash_obj, input_signature, process_obj)(*args)
return out

View File

@ -837,11 +837,14 @@ class Shard(Shard_):
return self.shard_fn
shard_ = Shard()
@ms_function(obj=fn)
def after_shard(*args):
return shard_(fn, in_strategy, out_strategy, device, level)(*args)
def shard_fn(*args):
args = (fn,) + args
@ms_function(hash_args=fn)
def after_shard(*args):
return shard_(fn, in_strategy, out_strategy, device, level)(*args)
return after_shard(*args)
self.shard_fn = after_shard
self.shard_fn = shard_fn
self.fn = fn
self.in_strategy = in_strategy
self.out_strategy = out_strategy