diff --git a/python/jittor_utils/load_pytorch.py b/python/jittor_utils/load_pytorch.py
index 156dd4ff..14ca7376 100644
--- a/python/jittor_utils/load_pytorch.py
+++ b/python/jittor_utils/load_pytorch.py
@@ -113,19 +113,20 @@ class UnpicklerWrapper(pickle.Unpickler):  # type: ignore[name-defined]
         return super().find_class(mod_name, name)

 class ArrayWrapper:
-    def __init__(self, storage, size=None, requires_grad=None):
+    def __init__(self, storage, stride=None, size=None, requires_grad=None):
         self.requires_grad = requires_grad
         self.size = size
         self.storage = storage
-
+        self.stride = stride
+
     def __str__(self):
         return self.storage.__str__()

 def jittor_rebuild_direct(storage, storage_offset, size, stride, requires_grad, backward_hooks):
     if len(size) == 0:
-        return ArrayWrapper(storage, size=size)
+        return ArrayWrapper(storage, stride=stride, size=size)
     storage.reshape(size)
-    return ArrayWrapper(storage, size=size)
+    return ArrayWrapper(storage, stride=stride, size=size)

 def jittor_rebuild_var_direct(data, requires_grad, backward_hooks):
     v = ArrayWrapper(storage, requires_grad=requires_grad)
@@ -255,7 +256,14 @@ def load_pytorch(fn_name):
         shape = params.size
         result[key] = jt.array(params.storage)
         if shape is not None and len(shape) > 0:
-            result[key] = result[key].reshape(shape)
+            if len(params.stride) > 1:
+                eval_list = []
+                for idx in range(len(params.stride)):
+                    eval_list.append(f"@e0({idx}) * i{idx}")
+                evals = "+".join(eval_list)
+                result[key] = result[key].reindex(params.size, [evals], extras=[jt.array(params.stride)])
+            else:
+                result[key] = result[key].reshape(shape)
         if requires_grad is not None:
             result[key].requires_grad = requires_grad
     return result
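
Why the reindex path: `jt.array(params.storage)` yields a flat one-dimensional var, so a tensor saved with a non-contiguous PyTorch stride cannot be recovered by `reshape` alone. The patch therefore carries the stride through `ArrayWrapper` and, at load time, builds the single index expression `@e0(0) * i0+@e0(1) * i1+...`: each output coordinate `i{d}` is multiplied by `stride[d]` (read from `extras[0]` via `@e0(d)`) and summed into a flat storage offset. Below is a minimal plain-Python sketch of that offset arithmetic, not part of the patch; `strided_gather` and the sample data are made up for illustration:

    # Sketch only: a plain-Python model of the offset math behind the patch's
    # reindex call. No Jittor required; `strided_gather` is a hypothetical name.
    def strided_gather(flat_storage, size, stride):
        """Rebuild a nested row-major list from flat storage plus per-dim strides."""
        def element(index):
            # Flat offset = sum(stride[d] * index[d]) -- the same sum the patch
            # encodes as "@e0(0) * i0+@e0(1) * i1+..." for Jittor's reindex.
            return flat_storage[sum(s * i for s, i in zip(stride, index))]

        def build(dim, prefix):
            if dim == len(size):
                return element(prefix)
            return [build(dim + 1, prefix + (i,)) for i in range(size[dim])]

        return build(0, ())

    # Example: a 2x3 view with stride (1, 2), i.e. the transpose of a
    # contiguous 3x2 tensor stored as [0, 1, 2, 3, 4, 5].
    print(strided_gather([0, 1, 2, 3, 4, 5], (2, 3), (1, 2)))
    # -> [[0, 2, 4], [1, 3, 5]]

    # The index expression the patch would build for this two-dim case:
    evals = "+".join(f"@e0({idx}) * i{idx}" for idx in range(2))
    print(evals)  # @e0(0) * i0+@e0(1) * i1

Because the input var is 1-D, one index expression (the flat offset) is all `reindex` needs; the target shape is supplied separately as `params.size`, and the stride array rides along in `extras`.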