forked from mindspore-Ecosystem/mindspore
!12015 cpu support scalar tensor
From: @huaweib Reviewed-by: @kisnwang,@zhoufeng54 Signed-off-by: @zhoufeng54
This commit is contained in:
commit
c7494512eb
|
@ -24,6 +24,10 @@ void AssignAddCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
|
||||
if (src1_shape.size() == 0 && src0_shape.size() == 0) {
|
||||
src0_shape.insert(src0_shape.begin(), 1);
|
||||
src1_shape.insert(src1_shape.begin(), 1);
|
||||
}
|
||||
if (src0_shape.size() != src1_shape.size() && src1_shape.size() > 1) {
|
||||
MS_LOG(EXCEPTION) << "AssignAdd only support same dim input or tensor * scalar " << src0_shape.size() << " vs "
|
||||
<< src1_shape.size();
|
||||
|
|
|
@ -130,7 +130,11 @@ dnnl::memory::format_tag MKLCPUKernel::GetDefaultFormatTag(const dnnl::memory::d
|
|||
|
||||
dnnl::memory::desc MKLCPUKernel::GetDefaultMemDesc(const std::vector<size_t> &shape) {
|
||||
dnnl::memory::dims dims;
|
||||
dims.insert(dims.end(), shape.begin(), shape.end());
|
||||
if (shape.size() == 0) {
|
||||
dims.insert(dims.end(), 1);
|
||||
} else {
|
||||
dims.insert(dims.end(), shape.begin(), shape.end());
|
||||
}
|
||||
dnnl::memory::format_tag mem_tag = GetDefaultFormatTag(dims);
|
||||
dnnl::memory::desc mem_desc(dims, dnnl::memory::data_type::f32, mem_tag);
|
||||
return mem_desc;
|
||||
|
|
|
@ -151,7 +151,7 @@ class CtcLossGpuKernel : public GpuKernel {
|
|||
void LaunchSecondHalf(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr);
|
||||
int SOffSet = 2 * max_labels_length_host + 1;
|
||||
const int SOffSet = 2 * max_labels_length_host + 1;
|
||||
int log_prob_size = batch * SOffSet * max_time;
|
||||
|
||||
if (!ignore_longer_outputs_than_inputs_ && max_labels_length_host > max_time) {
|
||||
|
|
Loading…
Reference in New Issue