Fixing bugs

This commit is contained in:
zhangz0911gm 2020-08-12 03:52:08 -04:00
parent a7157c42eb
commit a80e0cd111
4 changed files with 7 additions and 8 deletions

View File

@ -179,9 +179,9 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *src_prim) {
case schema::PrimitiveType_StridedSlice:
return new lite::StridedSlice(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Prelu:
return new lite::Prelu(const_cast<schema::Primitive *>(srcPrim));
return new lite::Prelu(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_CaffePReLU:
return new lite::CaffePReLU(const_cast<schema::Primitive *>(srcPrim));
return new lite::CaffePReLU(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Round:
return new lite::Round(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Reverse:

View File

@ -31,8 +31,7 @@ class PreluBaseCPUKernel : public LiteKernel {
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
const lite::Primitive *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx) {
opParameter->thread_num_ = ctx->thread_num_;
prelu_param_ = reinterpret_cast<PreluParameter *>(opParameter);
prelu_param_ = reinterpret_cast<PreluParameter *>(op_parameter_);
}
~PreluBaseCPUKernel() = default;

View File

@ -51,12 +51,12 @@ int CaffePReluCPUKernel::Run() {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
auto input = inputs_.at(0);
auto input1 = inputs_.at(1);
auto input = in_tensors_[0];
auto input1 = in_tensors_[1];
prelu_param_->input_num_ = input->ElementsNum();
input_data = reinterpret_cast<float *>(input->Data());
output_data = reinterpret_cast<float *>(outputs_.at(0)->Data());
output_data = reinterpret_cast<float *>(out_tensors_[0]->Data());
auto channels = input->shape();
prelu_param_->negtive_slope_ = reinterpret_cast<float *>(input1->Data());
prelu_param_->channel_num_ = channels.at(channels.size() - 1);

View File

@ -32,7 +32,7 @@ class CaffePReluCPUKernel : public LiteKernel {
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const lite::Primitive *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) {
prelu_param_ = (reinterpret_cast<CaffePReluParameter *>(opParameter));
prelu_param_ = reinterpret_cast<CaffePReluParameter *>(op_parameter_);
primitive_ = primitive;
}
~CaffePReluCPUKernel() = default;