This commit is contained in:
lianliguang 2021-07-08 16:15:46 +08:00
parent 7d47b12e3b
commit e5b0288076
1 changed files with 41 additions and 6 deletions

View File

@ -94,11 +94,19 @@ class Primitive(Primitive_):
def add_prim_attr(self, name, value):
"""
Adds primitive attribute.
Add primitive attribute.
Args:
name (str): Attribute Name.
value (Any): Attribute value.
Example:
>>>import mindspore.ops as P
>>>a = P.Add()
>>>a = a.add_prim_attr("attr",1)
>>>out = a.attrs["attr"]
>>>print(out)
1
"""
self.__dict__[name] = value
self.attrs[name] = value
@ -107,10 +115,17 @@ class Primitive(Primitive_):
def del_prim_attr(self, name):
"""
Del primitive attribute.
Delete primitive attribute.
Args:
name (str): Attribute Name.
Example:
>>>import mindspore.ops as P
>>>a = P.Add()
>>>a = a.add_prim_attr("attr",1)
>>>a = a.del_prim_attr("attr")
>>>a.attrs
{'input_names': ['x', 'y'], 'output_names' : ['output']}
"""
if name in self.__dict__ and name in self.attrs:
del self.__dict__[name]
@ -127,7 +142,7 @@ class Primitive(Primitive_):
In other parallel modes, please set it to be 0.
Args:
stage (int): The stage id for the current operation
stage (int): The stage id for the current operation.
"""
self.add_prim_attr("stage", stage)
return self
@ -165,6 +180,12 @@ class Primitive(Primitive_):
Args:
instance_name (str): Instance name of primitive operator set by user.
Example:
>>>import mindspore.ops as P
>>>a = P.Add()
>>>a.set_prim_instance_name("add")
>>>a.instance_name
'add'
"""
self.set_instance_name(instance_name)
self.instance_name = instance_name
@ -188,8 +209,8 @@ class Primitive(Primitive_):
Returns:
A tuple consisting of two elements.
The first element means if the primitive can be calculated in compiling stage
the second element is calculated result .
The first element means if the primitive can be calculated in compiling stage,
the second element is calculated result.
Examples:
>>> class AddN(Primitive):
@ -234,11 +255,19 @@ class Primitive(Primitive_):
def init_prim_io_names(self, inputs, outputs):
"""
Initializes the name of inputs and outputs of Tensor or attributes.
Initialize the name of inputs and outputs of Tensor or attributes.
Args:
inputs (list[str]): list of inputs names.
outputs (list[str]): list of outputs names.
Example:
>>>import mindspore.ops as P
>>>a = P.Add()
>>>a.init_prim_io_names(["x","y"],["sum"])
>>>a.input_names
['x','y']
>>>a.output_names
['sum']
"""
# for checking para names with kernel implementation
self.add_prim_attr("input_names", inputs)
@ -263,6 +292,12 @@ class Primitive(Primitive_):
Args:
mode (bool): Specifies whether the primitive is recomputed. Default: True.
Example:
>>>import mindspore.ops as P
>>>a = P.Add()
>>>a = a.recompute()
>>>a.recompute
True
"""
if context.get_context("mode") == context.PYNATIVE_MODE:
raise TypeError("Recompute is not supported in pynative mode currently.")