remove usage of obj in ms_function
This commit is contained in:
parent
dbfa01dfea
commit
0d3f2c6f3c
|
@ -1,7 +1,7 @@
|
|||
mindspore.ms_function
|
||||
=====================
|
||||
|
||||
.. py:function:: mindspore.ms_function(fn=None, obj=None, input_signature=None, hash_args=None)
|
||||
.. py:function:: mindspore.ms_function(fn=None, input_signature=None, hash_args=None)
|
||||
|
||||
将Python函数编译为一张可调用的MindSpore图。
|
||||
|
||||
|
@ -10,7 +10,6 @@ mindspore.ms_function
|
|||
**参数:**
|
||||
|
||||
- **fn** (Function) - 要编译成图的Python函数。默认值:None。
|
||||
- **obj** (Object) - 用于区分编译后函数的Python对象。默认值:None。
|
||||
- **input_signature** (Tensor) - 用于表示输入参数的Tensor。Tensor的shape和dtype将作为函数的输入shape和dtype。默认值:None。
|
||||
- **hash_args** (Union[Object, List or Tuple of Objects]) - `fn` 里面用到的自由变量,比如外部函数或类对象,再次调用时若 `hash_args` 出现变化会触发重新编译。默认值:None。
|
||||
|
||||
|
|
|
@ -287,6 +287,8 @@ class _MindsporeFunctionExecutor:
|
|||
str(self.fn.__code__.co_firstlineno) + '.' + str(id(self.fn))
|
||||
if _pynative_executor.grad_flag():
|
||||
generate_name = generate_name + ".grad"
|
||||
if is_pynative_parallel():
|
||||
generate_name = generate_name[:generate_name.rfind(str(id(self.fn)))] + str(id(self.shard_parent_obj))
|
||||
self.fn.__parse_method__ = method_name
|
||||
|
||||
# Add key with obj
|
||||
|
@ -427,7 +429,7 @@ def _get_ms_function_hash(hash_input):
|
|||
return _get_obj_id(hash_input)
|
||||
|
||||
|
||||
def ms_function(fn=None, obj=None, input_signature=None, hash_args=None):
|
||||
def ms_function(fn=None, input_signature=None, hash_args=None):
|
||||
"""
|
||||
Create a callable MindSpore graph from a Python function.
|
||||
|
||||
|
@ -435,7 +437,6 @@ def ms_function(fn=None, obj=None, input_signature=None, hash_args=None):
|
|||
|
||||
Args:
|
||||
fn (Function): The Python function that will be run as a graph. Default: None.
|
||||
obj (Object): The Python object is used to distinguish the compiled function. Default: None.
|
||||
input_signature (Tensor): The Tensor which describes the input arguments. The shape and dtype of the Tensor
|
||||
will be supplied to this function. If input_signature is specified, each input to `fn` must be a `Tensor`.
|
||||
And the input parameters of `fn` cannot accept `**kwargs`. The shape and dtype of actual inputs should
|
||||
|
@ -509,14 +510,12 @@ def ms_function(fn=None, obj=None, input_signature=None, hash_args=None):
|
|||
|
||||
@wraps(func)
|
||||
def staging_specialize(*args):
|
||||
if obj is not None:
|
||||
logger.warning("Obj is no longer in use, and the function's own object has been used to \
|
||||
distinguish whether it has been compiled.")
|
||||
process_obj = None
|
||||
if args and not isinstance(args[0], MsTensor) and hasattr(args[0], func.__name__):
|
||||
process_obj = args[0]
|
||||
if process_obj is None and is_pynative_parallel():
|
||||
process_obj = obj
|
||||
if is_pynative_parallel():
|
||||
process_obj = args[0]
|
||||
args = args[1:]
|
||||
out = _MindsporeFunctionExecutor(func, hash_obj, input_signature, process_obj)(*args)
|
||||
return out
|
||||
|
||||
|
|
|
@ -837,11 +837,14 @@ class Shard(Shard_):
|
|||
return self.shard_fn
|
||||
shard_ = Shard()
|
||||
|
||||
@ms_function(obj=fn)
|
||||
def after_shard(*args):
|
||||
return shard_(fn, in_strategy, out_strategy, device, level)(*args)
|
||||
def shard_fn(*args):
|
||||
args = (fn,) + args
|
||||
@ms_function(hash_args=fn)
|
||||
def after_shard(*args):
|
||||
return shard_(fn, in_strategy, out_strategy, device, level)(*args)
|
||||
return after_shard(*args)
|
||||
|
||||
self.shard_fn = after_shard
|
||||
self.shard_fn = shard_fn
|
||||
self.fn = fn
|
||||
self.in_strategy = in_strategy
|
||||
self.out_strategy = out_strategy
|
||||
|
|
Loading…
Reference in New Issue