!49835 Optimize StubTensor shape and dtype

Merge pull request !49835 from NaCN/fix_stub
This commit is contained in:
i-robot 2023-03-07 01:44:34 +00:00 committed by Gitee
commit dbf06f2174
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 2 additions and 4 deletions

View File

@ -53,8 +53,6 @@ class StubTensor:
def __init__(self, stub):
self.stub = stub
self.tensor = None
self.stub_shape = None
self.stub_dtype = None
__str__ = _stub_method(Tensor.__str__)
__setitem__ = _stub_method(Tensor.__setitem__)
@ -77,7 +75,7 @@ class StubTensor:
def shape(self):
"""shape stub."""
if self.stub:
if self.stub_shape is None:
if not hasattr(self, "stub_shape"):
self.stub_shape = self.stub.get_shape()
return self.stub_shape
return self.tensor.shape
@ -86,7 +84,7 @@ class StubTensor:
def dtype(self):
"""dtype stub."""
if self.stub:
if self.stub_dtype is None:
if not hasattr(self, "stub_dtype"):
self.stub_dtype = self.stub.get_dtype()
return self.stub_dtype
return self.tensor.dtype