forked from mindspore-Ecosystem/mindspore
support python func print and != for list with none
This commit is contained in:
parent
679dbd27b3
commit
7c233a57fa
|
@ -114,6 +114,7 @@ convert_object_map = {
|
|||
T.map: C.HyperMap(),
|
||||
T.partial: F.partial,
|
||||
T.zip: C.zip_operation,
|
||||
T.print: F.print_,
|
||||
|
||||
# custom define operation
|
||||
T.iter: M.ms_iter,
|
||||
|
|
|
@ -27,7 +27,7 @@ from operator import ( # noqa
|
|||
|
||||
# support system function call
|
||||
from builtins import ( # noqa
|
||||
bool, getattr, setattr, len, iter, next, pow, range, map, zip
|
||||
bool, getattr, setattr, len, iter, next, pow, range, map, zip, print
|
||||
)
|
||||
|
||||
# support functools
|
||||
|
@ -44,7 +44,7 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt',
|
|||
'not_', 'and_', 'or_', 'xor', 'lshift', 'rshift', 'invert', 'is_', 'is_not', 'contains',
|
||||
'matmul', 'getitem', 'setitem',
|
||||
'bool', 'getattr', 'setattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip',
|
||||
'partial',
|
||||
'partial', 'print',
|
||||
'exp', 'log', 'sin', 'cos', 'tan']
|
||||
|
||||
|
||||
|
|
|
@ -132,7 +132,7 @@ def _none_not_equal_scalar(x, y):
|
|||
|
||||
|
||||
@not_equal.register("Tuple", "Tuple")
|
||||
def _euqal_tuple(x, y):
|
||||
def _not_euqal_tuple(x, y):
|
||||
"""
|
||||
Determine if two tuples are not equal by element.
|
||||
|
||||
|
@ -147,7 +147,7 @@ def _euqal_tuple(x, y):
|
|||
|
||||
|
||||
@not_equal.register("List", "List")
|
||||
def _euqal_list(x, y):
|
||||
def _not_euqal_list(x, y):
|
||||
"""
|
||||
Determine if two lists are not equal by element.
|
||||
|
||||
|
@ -162,7 +162,7 @@ def _euqal_list(x, y):
|
|||
|
||||
|
||||
@not_equal.register("Tuple", "None")
|
||||
def _tuple_euqal_none(x, y):
|
||||
def _tuple_not_euqal_none(x, y):
|
||||
"""
|
||||
Determine if tuple element not equals none element.
|
||||
|
||||
|
@ -190,6 +190,7 @@ def _none_not_equal_tuple(x, y):
|
|||
"""
|
||||
return True
|
||||
|
||||
|
||||
@not_equal.register("Tensor", "Number")
|
||||
@not_equal.register("Number", "Tensor")
|
||||
@not_equal.register("Tensor", "Tensor")
|
||||
|
@ -235,3 +236,33 @@ def _none_not_equal_tensor(x, y):
|
|||
bool, return True.
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
@not_equal.register("List", "None")
|
||||
def _list_not_equal_none(x, y):
|
||||
"""
|
||||
Determine if list not equal none.
|
||||
|
||||
Args:
|
||||
x (list): The first input which is a list.
|
||||
y (none): The second input which is none.
|
||||
|
||||
Returns:
|
||||
bool, return true.
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
@not_equal.register("None", "List")
|
||||
def _none_not_equal_list(x, y):
|
||||
"""
|
||||
Determine if none not equal list.
|
||||
|
||||
Args:
|
||||
x (none): The first input which is none.
|
||||
y (list): The second input which is a list.
|
||||
|
||||
Returns:
|
||||
bool, return true.
|
||||
"""
|
||||
return True
|
||||
|
|
|
@ -66,7 +66,7 @@ scalar_to_array = P.ScalarToArray()
|
|||
scalar_to_tensor = P.ScalarToTensor()
|
||||
tuple_to_array = P.TupleToArray()
|
||||
scalar_cast = P.ScalarCast()
|
||||
|
||||
print_ = P.Print()
|
||||
|
||||
tuple_setitem = Primitive('tuple_setitem')
|
||||
tuple_getitem = Primitive('tuple_getitem')
|
||||
|
|
|
@ -108,6 +108,7 @@ class BinaryCrossEntropyGrad(PrimitiveWithInfer):
|
|||
validator.check_two_types_same('x_type', x_type, 'weight_type', weight_type)
|
||||
return x_type
|
||||
|
||||
|
||||
class ConcatOffset(PrimitiveWithInfer):
|
||||
"""primitive for computing Concat's gradient."""
|
||||
|
||||
|
|
|
@ -160,8 +160,10 @@ def test_ops():
|
|||
ret_floor = p // q + q // p
|
||||
ret = ret_pow + ret_mod + ret_floor
|
||||
if self.int > self.float:
|
||||
if self.str_a + self.str_b == "helloworld":
|
||||
return ret
|
||||
if [1, 2, 3] != None:
|
||||
if self.str_a + self.str_b == "helloworld":
|
||||
print("hello world")
|
||||
return ret
|
||||
return x
|
||||
|
||||
net = OpsNet(9, 2)
|
||||
|
|
|
@ -151,8 +151,6 @@ def vm_impl_max_pool_grad_with_argmax(self):
|
|||
"""Generate vm_impl function for MaxPoolGradWithArgmax"""
|
||||
|
||||
def vm_impl(x, dout, argmax):
|
||||
print("buxue")
|
||||
print(argmax)
|
||||
x = x.asnumpy()
|
||||
dout = dout.asnumpy()
|
||||
arg_max = argmax.asnumpy()
|
||||
|
|
Loading…
Reference in New Issue