forked from jittor/jittor
fix wrong transpose.
This commit is contained in:
parent
5eaccf538d
commit
b21ee40360
|
@ -113,19 +113,20 @@ class UnpicklerWrapper(pickle.Unpickler): # type: ignore[name-defined]
|
|||
return super().find_class(mod_name, name)
|
||||
|
||||
class ArrayWrapper:
|
||||
def __init__(self, storage, size=None, requires_grad=None):
|
||||
def __init__(self, storage, stride=None, size=None, requires_grad=None):
|
||||
self.requires_grad = requires_grad
|
||||
self.size = size
|
||||
self.storage = storage
|
||||
|
||||
self.stride = stride
|
||||
|
||||
def __str__(self):
|
||||
return self.storage.__str__()
|
||||
|
||||
def jittor_rebuild_direct(storage, storage_offset, size, stride, requires_grad, backward_hooks):
|
||||
if len(size) == 0:
|
||||
return ArrayWrapper(storage, size=size)
|
||||
return ArrayWrapper(storage, stride=stride, size=size)
|
||||
storage.reshape(size)
|
||||
return ArrayWrapper(storage, size=size)
|
||||
return ArrayWrapper(storage, stride=stride, size=size)
|
||||
|
||||
def jittor_rebuild_var_direct(data, requires_grad, backward_hooks):
|
||||
v = ArrayWrapper(storage, requires_grad=requires_grad)
|
||||
|
@ -255,7 +256,14 @@ def load_pytorch(fn_name):
|
|||
shape = params.size
|
||||
result[key] = jt.array(params.storage)
|
||||
if shape is not None and len(shape) > 0:
|
||||
result[key] = result[key].reshape(shape)
|
||||
if len(params.stride) > 1:
|
||||
eval_list = []
|
||||
for idx in range(len(params.stride)):
|
||||
eval_list.append(f"@e0({idx}) * i{idx}")
|
||||
evals = "+".join(eval_list)
|
||||
result[key] = result[key].reindex(params.size, [evals], extras=[jt.array(params.stride)])
|
||||
else:
|
||||
result[key] = result[key].reshape(shape)
|
||||
if requires_grad is not None:
|
||||
result[key].requires_grad = requires_grad
|
||||
return result
|
||||
|
|
Loading…
Reference in New Issue