forked from jittor/jittor
jtorch compatible
This commit is contained in:
parent 437a720500
commit 6fbffb7c62
@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
 
-__version__ = '1.3.6.9'
+__version__ = '1.3.6.10'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int
@@ -291,7 +291,7 @@ def median(x,dim=None,keepdim=False, keepdims=False):
     if dim is None:
         x = x.reshape(-1)
         dim=0
-    _,x = x.argsort(dim)
+    _,x = jt.argsort(x, dim)
     slices = [slice(None) for i in range(dim-1)]
     k = (x.shape[dim]-1)//2
     if keepdim:
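
The hunk above swaps the Var method call x.argsort(dim) for the module-level jt.argsort(x, dim). A plausible reading, given only the "jtorch compatible" commit message, is that jtorch re-binds Var methods to PyTorch-style signatures, while the module-level jt.argsort keeps the (indices, sorted values) return that median relies on. Below is a minimal sketch of that median-by-sorting step; the helper name lower_median_1d is illustrative and not part of the patch, and the only API assumption is the (indices, sorted values) unpacking shown in the diff itself.

import jittor as jt

def lower_median_1d(x):
    # Sort along dim 0 and take the lower-middle element, mirroring the
    # k = (n - 1) // 2 indexing used by the patched median() above.
    _, sorted_x = jt.argsort(x, 0)   # module-level call, as in the diff
    k = (x.shape[0] - 1) // 2
    return sorted_x[k]

x = jt.array([3.0, 1.0, 4.0, 2.0])
print(lower_median_1d(x))            # expected: 2.0
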
@@ -498,9 +498,9 @@ def softmax(x, dim=None, log=False):
         return code_softmax.softmax_v1(x, log)
     if dim is None: dim = ()
     if log:
-        a = x-x.max(dim, keepdims=True)
+        a = x - jt.max(x, dim, keepdims=True)
         return a - a.exp().sum(dim, keepdims=True).log()
-    x = (x-x.max(dim, keepdims=True)).exp()
+    x = (x - jt.max(x, dim, keepdims=True)).exp()
     return x / x.sum(dim, keepdims=True)
 jt.Var.softmax = softmax
 
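The softmax hunk makes the same kind of substitution, replacing x.max(dim, keepdims=True) with jt.max(x, dim, keepdims=True). The max subtraction itself is the standard numerical-stability trick: softmax is invariant to shifting its input, so subtracting the per-dim maximum leaves the result unchanged while keeping exp() from overflowing. A small sketch of that stabilised form follows, using only the calls visible in the diff; the stable_softmax helper name is illustrative, not part of the patch.

import jittor as jt

def stable_softmax(x, dim):
    # Subtracting the per-`dim` maximum does not change the result
    # (softmax is shift-invariant) but keeps exp() in a safe range.
    shifted = x - jt.max(x, dim, keepdims=True)
    e = shifted.exp()
    return e / e.sum(dim, keepdims=True)

x = jt.array([[1000.0, 1001.0, 1002.0]])
print(stable_softmax(x, 1))   # finite probabilities summing to 1 along dim 1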