support python func print and != for list with none

This commit is contained in:
buxue 2020-04-20 17:40:46 +08:00
parent 679dbd27b3
commit 7c233a57fa
7 changed files with 43 additions and 10 deletions

View File

@ -114,6 +114,7 @@ convert_object_map = {
T.map: C.HyperMap(), T.map: C.HyperMap(),
T.partial: F.partial, T.partial: F.partial,
T.zip: C.zip_operation, T.zip: C.zip_operation,
T.print: F.print_,
# custom define operation # custom define operation
T.iter: M.ms_iter, T.iter: M.ms_iter,

View File

@ -27,7 +27,7 @@ from operator import ( # noqa
# support system function call # support system function call
from builtins import ( # noqa from builtins import ( # noqa
bool, getattr, setattr, len, iter, next, pow, range, map, zip bool, getattr, setattr, len, iter, next, pow, range, map, zip, print
) )
# support functools # support functools
@ -44,7 +44,7 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt',
'not_', 'and_', 'or_', 'xor', 'lshift', 'rshift', 'invert', 'is_', 'is_not', 'contains', 'not_', 'and_', 'or_', 'xor', 'lshift', 'rshift', 'invert', 'is_', 'is_not', 'contains',
'matmul', 'getitem', 'setitem', 'matmul', 'getitem', 'setitem',
'bool', 'getattr', 'setattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip', 'bool', 'getattr', 'setattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip',
'partial', 'partial', 'print',
'exp', 'log', 'sin', 'cos', 'tan'] 'exp', 'log', 'sin', 'cos', 'tan']

View File

@ -132,7 +132,7 @@ def _none_not_equal_scalar(x, y):
@not_equal.register("Tuple", "Tuple") @not_equal.register("Tuple", "Tuple")
def _euqal_tuple(x, y): def _not_euqal_tuple(x, y):
""" """
Determine if two tuples are not equal by element. Determine if two tuples are not equal by element.
@ -147,7 +147,7 @@ def _euqal_tuple(x, y):
@not_equal.register("List", "List") @not_equal.register("List", "List")
def _euqal_list(x, y): def _not_euqal_list(x, y):
""" """
Determine if two lists are not equal by element. Determine if two lists are not equal by element.
@ -162,7 +162,7 @@ def _euqal_list(x, y):
@not_equal.register("Tuple", "None") @not_equal.register("Tuple", "None")
def _tuple_euqal_none(x, y): def _tuple_not_euqal_none(x, y):
""" """
Determine if tuple element not equals none element. Determine if tuple element not equals none element.
@ -190,6 +190,7 @@ def _none_not_equal_tuple(x, y):
""" """
return True return True
@not_equal.register("Tensor", "Number") @not_equal.register("Tensor", "Number")
@not_equal.register("Number", "Tensor") @not_equal.register("Number", "Tensor")
@not_equal.register("Tensor", "Tensor") @not_equal.register("Tensor", "Tensor")
@ -235,3 +236,33 @@ def _none_not_equal_tensor(x, y):
bool, return True. bool, return True.
""" """
return True return True
@not_equal.register("List", "None")
def _list_not_equal_none(x, y):
"""
Determine if list not equal none.
Args:
x (list): The first input which is a list.
y (none): The second input which is none.
Returns:
bool, return true.
"""
return True
@not_equal.register("None", "List")
def _none_not_equal_list(x, y):
"""
Determine if none not equal list.
Args:
x (none): The first input which is none.
y (list): The second input which is a list.
Returns:
bool, return true.
"""
return True

View File

@ -66,7 +66,7 @@ scalar_to_array = P.ScalarToArray()
scalar_to_tensor = P.ScalarToTensor() scalar_to_tensor = P.ScalarToTensor()
tuple_to_array = P.TupleToArray() tuple_to_array = P.TupleToArray()
scalar_cast = P.ScalarCast() scalar_cast = P.ScalarCast()
print_ = P.Print()
tuple_setitem = Primitive('tuple_setitem') tuple_setitem = Primitive('tuple_setitem')
tuple_getitem = Primitive('tuple_getitem') tuple_getitem = Primitive('tuple_getitem')

View File

@ -108,6 +108,7 @@ class BinaryCrossEntropyGrad(PrimitiveWithInfer):
validator.check_two_types_same('x_type', x_type, 'weight_type', weight_type) validator.check_two_types_same('x_type', x_type, 'weight_type', weight_type)
return x_type return x_type
class ConcatOffset(PrimitiveWithInfer): class ConcatOffset(PrimitiveWithInfer):
"""primitive for computing Concat's gradient.""" """primitive for computing Concat's gradient."""

View File

@ -160,8 +160,10 @@ def test_ops():
ret_floor = p // q + q // p ret_floor = p // q + q // p
ret = ret_pow + ret_mod + ret_floor ret = ret_pow + ret_mod + ret_floor
if self.int > self.float: if self.int > self.float:
if self.str_a + self.str_b == "helloworld": if [1, 2, 3] != None:
return ret if self.str_a + self.str_b == "helloworld":
print("hello world")
return ret
return x return x
net = OpsNet(9, 2) net = OpsNet(9, 2)

View File

@ -151,8 +151,6 @@ def vm_impl_max_pool_grad_with_argmax(self):
"""Generate vm_impl function for MaxPoolGradWithArgmax""" """Generate vm_impl function for MaxPoolGradWithArgmax"""
def vm_impl(x, dout, argmax): def vm_impl(x, dout, argmax):
print("buxue")
print(argmax)
x = x.asnumpy() x = x.asnumpy()
dout = dout.asnumpy() dout = dout.asnumpy()
arg_max = argmax.asnumpy() arg_max = argmax.asnumpy()