fix trace

This commit is contained in:
huangmengxi 2021-06-09 11:31:51 +08:00
parent 5bcaae94a7
commit 733158ac31
2 changed files with 8 additions and 1 deletions

View File

@ -25,6 +25,7 @@ from collections.abc import Iterable
import numpy as np
from mindspore import log as logger
from mindspore.common import dtype as mstype
from mindspore._c_expression import Tensor as Tensor_
class Rel(Enum):
"""Numerical relationship between variables, logical relationship enumeration definition of range."""
@ -835,6 +836,11 @@ class Validator:
new_axes += (ax,)
return new_axes
@staticmethod
def empty_compile(dtype, shape):
"""Returns an empty Tensor."""
return Tensor_(dtype, shape)
def check_input_format(input_param):
"""Judge input format."""

View File

@ -777,7 +777,7 @@ def diagonal(x, offset=0, axis1=0, axis2=1):
last_dim_end = min_(
shape[-2], max_(0, shape[-1] - offset)) - last_dim_begin
if last_dim_end <= 0:
return Tensor([])
return empty_compile(dtype, (0,))
size += (last_dim_end,)
res = F.tensor_slice(res, begin, size)
return res.astype(dtype)
@ -1628,6 +1628,7 @@ infer_out_shape = constexpr(validator.infer_out_shape)
get_log2_size = constexpr(validator.get_log2_size)
check_axis_type = constexpr(validator.check_axis_type)
check_and_canonicalize_axes = constexpr(validator.check_and_canonicalize_axes)
empty_compile = constexpr(validator.empty_compile)
def tensor_bool(x):
"""tensor as condition, if is constant, return immediate bool value"""