fix wrong transpose.

This commit is contained in:
Exusial 2022-10-03 18:38:14 +08:00
parent 5eaccf538d
commit b21ee40360
1 changed files with 13 additions and 5 deletions

View File

@ -113,19 +113,20 @@ class UnpicklerWrapper(pickle.Unpickler): # type: ignore[name-defined]
return super().find_class(mod_name, name)
class ArrayWrapper:
def __init__(self, storage, size=None, requires_grad=None):
def __init__(self, storage, stride=None, size=None, requires_grad=None):
self.requires_grad = requires_grad
self.size = size
self.storage = storage
self.stride = stride
def __str__(self):
return self.storage.__str__()
def jittor_rebuild_direct(storage, storage_offset, size, stride, requires_grad, backward_hooks):
if len(size) == 0:
return ArrayWrapper(storage, size=size)
return ArrayWrapper(storage, stride=stride, size=size)
storage.reshape(size)
return ArrayWrapper(storage, size=size)
return ArrayWrapper(storage, stride=stride, size=size)
def jittor_rebuild_var_direct(data, requires_grad, backward_hooks):
v = ArrayWrapper(storage, requires_grad=requires_grad)
@ -255,7 +256,14 @@ def load_pytorch(fn_name):
shape = params.size
result[key] = jt.array(params.storage)
if shape is not None and len(shape) > 0:
result[key] = result[key].reshape(shape)
if len(params.stride) > 1:
eval_list = []
for idx in range(len(params.stride)):
eval_list.append(f"@e0({idx}) * i{idx}")
evals = "+".join(eval_list)
result[key] = result[key].reindex(params.size, [evals], extras=[jt.array(params.stride)])
else:
result[key] = result[key].reshape(shape)
if requires_grad is not None:
result[key].requires_grad = requires_grad
return result