fix trace
This commit is contained in:
parent
5bcaae94a7
commit
733158ac31
|
@ -25,6 +25,7 @@ from collections.abc import Iterable
|
|||
import numpy as np
|
||||
from mindspore import log as logger
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore._c_expression import Tensor as Tensor_
|
||||
|
||||
class Rel(Enum):
|
||||
"""Numerical relationship between variables, logical relationship enumeration definition of range."""
|
||||
|
@ -835,6 +836,11 @@ class Validator:
|
|||
new_axes += (ax,)
|
||||
return new_axes
|
||||
|
||||
@staticmethod
|
||||
def empty_compile(dtype, shape):
|
||||
"""Returns an empty Tensor."""
|
||||
return Tensor_(dtype, shape)
|
||||
|
||||
|
||||
def check_input_format(input_param):
|
||||
"""Judge input format."""
|
||||
|
|
|
@ -777,7 +777,7 @@ def diagonal(x, offset=0, axis1=0, axis2=1):
|
|||
last_dim_end = min_(
|
||||
shape[-2], max_(0, shape[-1] - offset)) - last_dim_begin
|
||||
if last_dim_end <= 0:
|
||||
return Tensor([])
|
||||
return empty_compile(dtype, (0,))
|
||||
size += (last_dim_end,)
|
||||
res = F.tensor_slice(res, begin, size)
|
||||
return res.astype(dtype)
|
||||
|
@ -1628,6 +1628,7 @@ infer_out_shape = constexpr(validator.infer_out_shape)
|
|||
get_log2_size = constexpr(validator.get_log2_size)
|
||||
check_axis_type = constexpr(validator.check_axis_type)
|
||||
check_and_canonicalize_axes = constexpr(validator.check_and_canonicalize_axes)
|
||||
empty_compile = constexpr(validator.empty_compile)
|
||||
|
||||
def tensor_bool(x):
|
||||
"""tensor as condition, if is constant, return immediate bool value"""
|
||||
|
|
Loading…
Reference in New Issue