jtorch compatible

This commit is contained in:
Dun Liang 2023-01-06 20:48:04 +08:00
parent 437a720500
commit 6fbffb7c62
3 changed files with 4 additions and 4 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.6.9'
__version__ = '1.3.6.10'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -291,7 +291,7 @@ def median(x,dim=None,keepdim=False, keepdims=False):
if dim is None:
x = x.reshape(-1)
dim=0
_,x = x.argsort(dim)
_,x = jt.argsort(x, dim)
slices = [slice(None) for i in range(dim-1)]
k = (x.shape[dim]-1)//2
if keepdim:

View File

@ -498,9 +498,9 @@ def softmax(x, dim=None, log=False):
return code_softmax.softmax_v1(x, log)
if dim is None: dim = ()
if log:
a = x-x.max(dim, keepdims=True)
a = x - jt.max(x, dim, keepdims=True)
return a - a.exp().sum(dim, keepdims=True).log()
x = (x-x.max(dim, keepdims=True)).exp()
x = (x - jt.max(x, dim, keepdims=True)).exp()
return x / x.sum(dim, keepdims=True)
jt.Var.softmax = softmax