forked from mindspore-Ecosystem/mindspore
!40990 dynamicshape fixbug
Merge pull request !40990 from TuDouNi/codeclean
This commit is contained in:
commit
cf616c0f05
|
@ -23,7 +23,7 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
void TensorShapeKernelMod::Execute() const {
|
void TensorShapeKernelMod::Execute(void *stream_ptr) const {
|
||||||
MS_LOG(INFO) << "Execute TensorShapeKernel Start";
|
MS_LOG(INFO) << "Execute TensorShapeKernel Start";
|
||||||
auto node = anf_node_.lock();
|
auto node = anf_node_.lock();
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
@ -59,10 +59,10 @@ void TensorShapeKernelMod::Execute() const {
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// cppcheck-suppress unreadVariable
|
// cppcheck-suppress unreadVariable
|
||||||
auto lock = device::KernelRuntime::LockRuntime(stream_);
|
auto lock = device::KernelRuntime::LockRuntime(stream_ptr);
|
||||||
auto status =
|
auto status =
|
||||||
rtMemcpyAsync(const_cast<void *>(output_addr->GetPtr()), output_addr->GetSize(), output_tensor_for_sync->data_c(),
|
rtMemcpyAsync(const_cast<void *>(output_addr->GetPtr()), output_addr->GetSize(), output_tensor_for_sync->data_c(),
|
||||||
LongToSize(output_tensor_for_sync->data().nbytes()), RT_MEMCPY_HOST_TO_DEVICE_EX, stream_);
|
LongToSize(output_tensor_for_sync->data().nbytes()), RT_MEMCPY_HOST_TO_DEVICE_EX, stream_ptr);
|
||||||
if (status != RT_ERROR_NONE) {
|
if (status != RT_ERROR_NONE) {
|
||||||
MS_LOG(EXCEPTION) << "Execute TensorShapeKernel rtMemcpyAsync failed!";
|
MS_LOG(EXCEPTION) << "Execute TensorShapeKernel rtMemcpyAsync failed!";
|
||||||
}
|
}
|
||||||
|
@ -78,7 +78,7 @@ bool TensorShapeKernelMod::Launch(const std::vector<AddressPtr> &, const std::ve
|
||||||
auto cnode = node->cast<CNodePtr>();
|
auto cnode = node->cast<CNodePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
try {
|
try {
|
||||||
Execute();
|
Execute(stream_ptr);
|
||||||
} catch (const std::exception &e) {
|
} catch (const std::exception &e) {
|
||||||
MS_LOG(ERROR) << "TensorShapeKernelMod Launch failed. node: " << cnode->fullname_with_scope()
|
MS_LOG(ERROR) << "TensorShapeKernelMod Launch failed. node: " << cnode->fullname_with_scope()
|
||||||
<< ", Error message is " << e.what();
|
<< ", Error message is " << e.what();
|
||||||
|
|
|
@ -30,7 +30,7 @@ class TensorShapeKernelMod : public HostKernelMod {
|
||||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
|
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void Execute() const;
|
void Execute(void *stream_ptr) const;
|
||||||
};
|
};
|
||||||
MS_HOST_REG_KERNEL(DynamicShape, TensorShapeKernelMod);
|
MS_HOST_REG_KERNEL(DynamicShape, TensorShapeKernelMod);
|
||||||
MS_HOST_REG_KERNEL(TensorShape, TensorShapeKernelMod);
|
MS_HOST_REG_KERNEL(TensorShape, TensorShapeKernelMod);
|
||||||
|
|
Loading…
Reference in New Issue