forked from mindspore-Ecosystem/mindspore
commit b486abc076

@@ -29,36 +29,7 @@ void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
using tag = dnnl::memory::format_tag;
using dim = dnnl::memory::dims;
std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
std::vector<size_t> src_c_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
bidirectional_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "bidirectional");
input_size_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "input_size");
hidden_size_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "hidden_size");
num_layers_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "num_layers");
has_bias_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "has_bias");
batch_size_ = SizeToInt(src_shape[1]);
seq_len_ = SizeToInt(src_shape[0]);
num_directions_ = 1;
if (bidirectional_) {
num_directions_ = 2;
}
if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) {
MS_LOG(EXCEPTION) << "error iteration shape!";
}
if (num_layers_ <= 0) {
MS_LOG(EXCEPTION) << "layers must be greater than zero!";
}
if (src_shape.size() != 3 || src_h_shape.size() != 3 || src_c_shape.size() != 3) {
MS_LOG(EXCEPTION) << "conv2d only support 3-D input!";
}
const int gate_size = 4 * hidden_size_;
for (int i = 0; i < num_layers_; ++i) {
weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_);
weight_h_size_ += gate_size * hidden_size_;
}
weight_size_ = weight_size_ * num_directions_;
weight_h_size_ = weight_h_size_ * num_directions_;
CheckParam(kernel_node);
auto eng = MKLKernelEngine::Get().engine();
dnnl::stream s(eng);
dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional;
@@ -99,6 +70,39 @@ void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) {
AddArgument(DNNL_ARG_WORKSPACE, prim_desc_.workspace_desc());
}

void LstmCPUKernel::CheckParam(const CNodePtr &kernel_node) {
std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
std::vector<size_t> src_c_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
bidirectional_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "bidirectional");
input_size_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "input_size");
hidden_size_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "hidden_size");
num_layers_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "num_layers");
has_bias_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "has_bias");
batch_size_ = SizeToInt(src_shape[1]);
seq_len_ = SizeToInt(src_shape[0]);
num_directions_ = 1;
if (bidirectional_) {
num_directions_ = 2;
}
const int gate_size = 4 * hidden_size_;
for (int i = 0; i < num_layers_; ++i) {
weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_);
weight_h_size_ += gate_size * hidden_size_;
}
weight_size_ = weight_size_ * num_directions_;
weight_h_size_ = weight_h_size_ * num_directions_;
if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) {
MS_LOG(EXCEPTION) << "error iteration shape!";
}
if (num_layers_ <= 0) {
MS_LOG(EXCEPTION) << "layers must be greater than zero!";
}
if (src_shape.size() != 3 || src_h_shape.size() != 3 || src_c_shape.size() != 3) {
MS_LOG(EXCEPTION) << "lstm only support 3-D input!";
}
}

bool LstmCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {

@@ -37,6 +37,7 @@ class LstmCPUKernel : public MKLCPUKernel {
const std::vector<AddressPtr> &outputs) override;

private:
void CheckParam(const CNodePtr &kernel_node);
int weight_size_ = 0;
int weight_h_size_ = 0;
int input_size_;

@@ -28,37 +28,8 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
using tag = dnnl::memory::format_tag;
using dim = dnnl::memory::dims;
CheckParam(kernel_node);
auto eng = MKLKernelEngine::Get().engine();
std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
std::vector<size_t> src_c_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
bidirectional_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "bidirectional");
input_size_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "input_size");
hidden_size_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "hidden_size");
num_layers_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "num_layers");
has_bias_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "has_bias");
batch_size_ = SizeToInt(src_shape[1]);
seq_len_ = SizeToInt(src_shape[0]);
num_directions_ = 1;
if (bidirectional_) {
num_directions_ = 2;
}
if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) {
MS_LOG(EXCEPTION) << "error iteration shape!";
}
if (num_layers_ <= 0) {
MS_LOG(EXCEPTION) << "layers must be greater than zero!";
}
if (src_shape.size() != 3 || src_h_shape.size() != 3 || src_c_shape.size() != 3) {
MS_LOG(EXCEPTION) << "conv2d only support 3-D input!";
}
const int gate_size = 4 * hidden_size_;
for (int i = 0; i < num_layers_; ++i) {
weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_);
weight_h_size_ += gate_size * hidden_size_;
}
weight_size_ = weight_size_ * num_directions_;
weight_h_size_ = weight_h_size_ * num_directions_;
dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional;
if (bidirectional_) {
direction = dnnl::rnn_direction::bidirectional_concat;
@@ -91,7 +62,14 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
dst_h_desc, dst_c_desc);
prim_backward_desc_ = dnnl::lstm_backward::primitive_desc(*backward_desc, eng, prim_forward_desc);
primitive_ = std::make_shared<dnnl::lstm_backward>(prim_backward_desc_);
AddArgument(DNNL_ARG_WORKSPACE, prim_forward_desc.workspace_desc());
AddArgumentOp(src_desc, src_h_desc, src_c_desc, bias_desc, dst_desc, dst_h_desc, dst_c_desc);
}

void LSTMGradCPUKernel::AddArgumentOp(const dnnl::memory::desc &src_desc, const dnnl::memory::desc &src_h_desc,
const dnnl::memory::desc &src_c_desc, const dnnl::memory::desc &bias_desc,
const dnnl::memory::desc &dst_desc, const dnnl::memory::desc &dst_h_desc,
const dnnl::memory::desc &dst_c_desc) {
AddArgument(DNNL_ARG_SRC_LAYER, src_desc);
AddArgument(DNNL_ARG_SRC_ITER, src_h_desc);
AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc);
@@ -101,7 +79,6 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
AddArgument(DNNL_ARG_DST_LAYER, dst_desc);
AddArgument(DNNL_ARG_DST_ITER, dst_h_desc);
AddArgument(DNNL_ARG_DST_ITER_C, dst_c_desc);
AddArgument(DNNL_ARG_WORKSPACE, prim_forward_desc.workspace_desc());
AddArgument(DNNL_ARG_DIFF_SRC_LAYER, src_desc);
AddArgument(DNNL_ARG_DIFF_SRC_ITER, src_h_desc);
AddArgument(DNNL_ARG_DIFF_SRC_ITER_C, src_c_desc);
@@ -113,6 +90,72 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
AddArgument(DNNL_ARG_DIFF_DST_ITER_C, dst_c_desc);
}

void LSTMGradCPUKernel::CheckParam(const CNodePtr &kernel_node) {
std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
std::vector<size_t> src_c_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
bidirectional_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "bidirectional");
input_size_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "input_size");
hidden_size_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "hidden_size");
num_layers_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "num_layers");
has_bias_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "has_bias");
batch_size_ = SizeToInt(src_shape[1]);
seq_len_ = SizeToInt(src_shape[0]);
num_directions_ = 1;
if (bidirectional_) {
num_directions_ = 2;
}
const int gate_size = 4 * hidden_size_;
for (int i = 0; i < num_layers_; ++i) {
weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_);
weight_h_size_ += gate_size * hidden_size_;
}
weight_size_ = weight_size_ * num_directions_;
weight_h_size_ = weight_h_size_ * num_directions_;
if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) {
MS_LOG(EXCEPTION) << "error iteration shape!";
}
if (num_layers_ <= 0) {
MS_LOG(EXCEPTION) << "layers must be greater than zero!";
}
if (src_shape.size() != 3 || src_h_shape.size() != 3 || src_c_shape.size() != 3) {
MS_LOG(EXCEPTION) << "lstm only support 3-D input!";
}
}

void LSTMGradCPUKernel::SetArgumentHandleOp(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs,
const dnnl::memory &weights_memory, const dnnl::memory &weights_h_memory,
const dnnl::memory &bias_memory, const dnnl::memory &diff_weights_memory,
const dnnl::memory &diff_weights_h_memory,
const dnnl::memory &diff_bias_memory) {
SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr);
SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr);
SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr);
SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, weights_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, weights_h_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_BIAS, bias_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_DST_LAYER, inputs[4]->addr);
SetArgumentHandle(DNNL_ARG_DST_ITER, inputs[5]->addr);
SetArgumentHandle(DNNL_ARG_DST_ITER_C, inputs[6]->addr);
SetArgumentHandle(DNNL_ARG_WORKSPACE, inputs[10]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_SRC_LAYER, outputs[0]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER, outputs[1]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER_C, outputs[2]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_h_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_DIFF_BIAS, diff_bias_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_DIFF_DST_LAYER, inputs[7]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER, inputs[8]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER_C, inputs[9]->addr);
}

void LSTMGradCPUKernel::Memset_op(const dnnl::memory &mem, string name) {
if (memset_s(mem.get_data_handle(), mem.get_desc().get_size(), 0, mem.get_desc().get_size())) {
MS_LOG(EXCEPTION) << name << " memset error";
}
}

bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
@@ -145,14 +188,10 @@ bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
auto user_diff_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng);
user_diff_weights_memory.set_data_handle(outputs[3]->addr);
user_diff_weights_h_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_);
if (memset_s(user_diff_weights_memory.get_data_handle(), user_diff_weights_memory.get_desc().get_size(), 0,
user_diff_weights_memory.get_desc().get_size())) {
MS_LOG(EXCEPTION) << "user weights grad memset error";
}
if (memset_s(user_diff_weights_h_memory.get_data_handle(), user_diff_weights_h_memory.get_desc().get_size(), 0,
user_diff_weights_h_memory.get_desc().get_size())) {
MS_LOG(EXCEPTION) << "user weights iter grad memset error";
}
Memset_op(user_diff_weights_memory, "user weights grad");
Memset_op(user_diff_weights_h_memory, "user weights iter grad");
Memset_op(diff_weights_memory, "weights grad");
Memset_op(diff_weights_h_memory, "weights iter grad");
if (has_bias_) {
diff_bias_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_ + weight_h_size_);
}
@@ -160,33 +199,8 @@ bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
prim_backward_desc_.diff_bias_desc().get_size())) {
MS_LOG(EXCEPTION) << "bias grad memset error";
}
if (memset_s(diff_weights_memory.get_data_handle(), diff_weights_memory.get_desc().get_size(), 0,
diff_weights_memory.get_desc().get_size())) {
MS_LOG(EXCEPTION) << "weights grad memset error";
}
if (memset_s(diff_weights_h_memory.get_data_handle(), diff_weights_h_memory.get_desc().get_size(), 0,
diff_weights_h_memory.get_desc().get_size())) {
MS_LOG(EXCEPTION) << "weights iter grad memset error";
}
SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr);
SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr);
SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr);
SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, weights_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, weights_h_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_BIAS, bias_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_DST_LAYER, inputs[4]->addr);
SetArgumentHandle(DNNL_ARG_DST_ITER, inputs[5]->addr);
SetArgumentHandle(DNNL_ARG_DST_ITER_C, inputs[6]->addr);
SetArgumentHandle(DNNL_ARG_WORKSPACE, inputs[10]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_SRC_LAYER, outputs[0]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER, outputs[1]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER_C, outputs[2]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_h_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_DIFF_BIAS, diff_bias_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_DIFF_DST_LAYER, inputs[7]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER, inputs[8]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER_C, inputs[9]->addr);
SetArgumentHandleOp(inputs, outputs, weights_memory, weights_h_memory, bias_memory, diff_weights_memory,
diff_weights_h_memory, diff_bias_memory);
ExecutePrimitive();
Reorder(&diff_weights_memory, &user_diff_weights_memory);
Reorder(&diff_weights_h_memory, &user_diff_weights_h_memory);

@@ -16,6 +16,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LSTM_GRAD_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LSTM_GRAD_CPU_KERNEL_H_

#include <string>
#include <vector>
#include <memory>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h"
@@ -32,6 +33,17 @@ class LSTMGradCPUKernel : public MKLCPUKernel {
const std::vector<AddressPtr> &outputs) override;

private:
void AddArgumentOp(const dnnl::memory::desc &src_desc, const dnnl::memory::desc &src_h_desc,
const dnnl::memory::desc &src_c_desc, const dnnl::memory::desc &bias_desc,
const dnnl::memory::desc &dst_desc, const dnnl::memory::desc &dst_h_desc,
const dnnl::memory::desc &dst_c_desc);
void SetArgumentHandleOp(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs, const dnnl::memory &weights_memory,
const dnnl::memory &weights_h_memory, const dnnl::memory &bias_memory,
const dnnl::memory &diff_weights_memory, const dnnl::memory &diff_weights_h_memory,
const dnnl::memory &diff_bias_memory);
void Memset_op(const dnnl::memory &mem, string name);
void CheckParam(const CNodePtr &kernel_node);
int weight_size_ = 0;
int weight_h_size_ = 0;
int input_size_;

@@ -79,9 +79,6 @@ void ReduceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
bool ReduceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspaces*/,
const std::vector<kernel::AddressPtr> &outputs) {
if (inputs.empty() || outputs.empty()) {
MS_LOG(EXCEPTION) << "input or output empty!";
}
size_t out_float_size = left_dims_ * sizeof(float);
size_t in_float_size = stride_ * out_float_size;
if (inputs[0]->size != in_float_size || outputs[0]->size != out_float_size) {
@@ -106,6 +103,11 @@ bool ReduceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
}
(void)transpose_axis.insert(transpose_axis.end(), axis_.begin(), axis_.end());
Transpose(size, input, shape_, transpose_axis, SizeToInt(shape_.size()), &new_input[0]);
ConvertDataToOutput(&new_input[0], output);
return true;
}

void ReduceCPUKernel::ConvertDataToOutput(const float *new_input, float *output) {
if (reduce_type_ == kReduceTypeMax) {
for (size_t i = 0; i < left_dims_; ++i) {
float value = new_input[i * stride_];
@@ -129,7 +131,6 @@ bool ReduceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
}
}
}
return true;
}
void ReduceCPUKernel::Transpose(const int size, const float *input, const std::vector<size_t> &input_shape,
const std::vector<size_t> &input_axis, const int shape_size, float *output) {

@@ -34,7 +34,8 @@ class ReduceCPUKernel : public CPUKernel {
private:
void Transpose(const int size, const float *input, const std::vector<size_t> &input_shape,
const std::vector<size_t> &input_axis, const int shape_size, float *output);
size_t reduce_type_;
void ConvertDataToOutput(const float *input, float *output);
size_t reduce_type_ = 0;
std::vector<size_t> axis_;
std::vector<size_t> shape_;
size_t left_dims_ = 1;

@@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define PATH_MAX 0x3ffff
#define PATH_MAX 4096
#include "runtime/device/ascend/ascend_kernel_runtime.h"
#include <string>
#include <vector>