forked from mindspore-Ecosystem/mindspore
!641 pynative-lamb-op-zeros-like-tensor-query-failed
Merge pull request !641 from JoyLvliang/pynative-lamb-op-zeros-like-tensor-query-failed
This commit is contained in:
commit
c1ef1e0aee
|
@ -86,7 +86,7 @@ def identity(x):
|
|||
def zeros_like_tensor(x):
|
||||
"""Implement `zeros_like_tensor`."""
|
||||
x = x.asnumpy()
|
||||
value = Tensor(np.zeros(x.shape))
|
||||
value = Tensor(np.zeros(x.shape).astype(np.float32))
|
||||
return value
|
||||
|
||||
|
||||
|
|
|
@ -26,7 +26,6 @@ void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph)
|
|||
void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -39,7 +39,7 @@
|
|||
|
||||
const char SINGLE_OP_GRAPH[] = "single_op_graph";
|
||||
// primitive unable to infer value for constant input in PyNative mode
|
||||
const std::unordered_set<std::string> vm_operators = {"partial", "depend", "make_ref"};
|
||||
const std::set<std::string> vm_operators = {"partial", "depend", "make_ref", "zeros_like_tensor"};
|
||||
|
||||
namespace mindspore {
|
||||
namespace pynative {
|
||||
|
|
|
@ -45,6 +45,9 @@ class ScalarSummary(Primitive):
|
|||
def __init__(self):
|
||||
"""init"""
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class ImageSummary(Primitive):
|
||||
"""
|
||||
|
@ -70,6 +73,9 @@ class ImageSummary(Primitive):
|
|||
def __init__(self):
|
||||
"""init"""
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class TensorSummary(Primitive):
|
||||
"""
|
||||
|
@ -97,6 +103,9 @@ class TensorSummary(Primitive):
|
|||
def __init__(self):
|
||||
"""init"""
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class HistogramSummary(Primitive):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue