[MSLITE][DEVELOP] add npu op fullconnection, reduce_mean; add npu testcases
parent 4fcdcb59a3
commit 7fa7b9d23b
@ -21,6 +21,9 @@

namespace mindspore::lite {
bool CheckFusion(kernel::LiteKernel *kernel) {
  if (kernel->in_kernels().empty() || kernel->out_kernels().empty()) {
    return false;
  }
  auto pre_flag =
    std::all_of(kernel->in_kernels().begin(), kernel->in_kernels().end(), [](const kernel::LiteKernel *in_kernel) {
      return NPUPassUtils::IsNchw2Nhwc(in_kernel) && in_kernel->out_kernels().size() == 1;
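As context for the hunk above: a kernel passes this pre-check only when every producer feeding it is an NCHW-to-NHWC transpose whose output has exactly one consumer, so the transpose can be folded into the NPU subgraph safely. A minimal standalone sketch of that predicate, with a hypothetical Node struct and IsNchw2Nhwc helper standing in for the real kernel::LiteKernel and NPUPassUtils API:

#include <algorithm>
#include <string>
#include <vector>

// Hypothetical stand-ins for kernel::LiteKernel and NPUPassUtils::IsNchw2Nhwc.
struct Node {
  std::string type;
  std::vector<Node *> in_nodes;
  std::vector<Node *> out_nodes;
};

bool IsNchw2Nhwc(const Node *n) { return n->type == "Nchw2Nhwc"; }

// Mirrors pre_flag in CheckFusion: every input must be a layout transpose
// that feeds exactly one consumer.
bool PreCheck(const Node *node) {
  if (node->in_nodes.empty() || node->out_nodes.empty()) {
    return false;
  }
  return std::all_of(node->in_nodes.begin(), node->in_nodes.end(), [](const Node *in) {
    return IsNchw2Nhwc(in) && in->out_nodes.size() == 1;
  });
}

int main() {
  Node transpose{"Nchw2Nhwc", {}, {}};
  Node conv{"Conv2D", {&transpose}, {}};
  Node sink{"Output", {&conv}, {}};
  transpose.out_nodes = {&conv};
  conv.out_nodes = {&sink};
  return PreCheck(&conv) ? 0 : 1;  // 0: conv qualifies for fusion
}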
@ -34,7 +34,7 @@ ConvolutionBaseNPUKernel::~ConvolutionBaseNPUKernel() {
  }
}

-int ConvolutionBaseNPUKernel::InitWeightBiasConst(const std::vector<lite::Tensor *> &inputs) {
+int ConvolutionBaseNPUKernel::InitWeightConst(const std::vector<lite::Tensor *> &inputs) {
  weight_ = new (std::nothrow) hiai::op::Const(name_ + "_w");
  if (weight_ == nullptr) {
    MS_LOG(ERROR) << "New weight const failed.";
@ -61,7 +61,10 @@ int ConvolutionBaseNPUKernel::InitWeightBiasConst(const std::vector<lite::Tensor

  weight_->set_attr_value(weight_tensor);
  free(nchw_data);
+  return RET_OK;
+}

+int ConvolutionBaseNPUKernel::InitBiasConst(const std::vector<lite::Tensor *> &inputs) {
  if (inputs.size() >= 3) {
    bias_ = new (std::nothrow) hiai::op::Const(name_ + "_b");
    if (bias_ == nullptr) {
@ -88,7 +91,7 @@ int ConvolutionBaseNPUKernel::SetActivation(const ge::Operator *input, ActType a
  } else if (act_type == ActType_Relu6) {
    act_->set_attr_mode(14);
  } else {
-    MS_LOG(ERROR) << "Unsupport activation for convolution.";
+    MS_LOG(ERROR) << "Unsupport activation type for convolution.";
    return RET_ERROR;
  }
  return RET_OK;
@ -32,7 +32,8 @@ class ConvolutionBaseNPUKernel : public NPUKernel {
  ~ConvolutionBaseNPUKernel() override;

 protected:
-  int InitWeightBiasConst(const std::vector<lite::Tensor *> &inputs);
+  int InitWeightConst(const std::vector<lite::Tensor *> &inputs);
+  int InitBiasConst(const std::vector<lite::Tensor *> &inputs);
  int SetActivation(const ge::Operator *input, ActType act_type);
  hiai::op::Activation *act_ = nullptr;
  hiai::op::Const *weight_ = nullptr;
@ -39,7 +39,7 @@ int ConvolutionDepthwiseNPUKernel::SetConvDwParam() {
    conv_dw_->set_attr_pad_mode(ge::AttrValue::STR{"VALID"});
    conv_dw_->set_attr_pads(ge::AttrValue::LIST_INT({0, 0, 0, 0}));
  } else {
-    conv_dw_->set_attr_pad_mode(ge::AttrValue::STR{"SPECIFIC"});
+    conv_dw_->set_attr_pad_mode(ge::AttrValue::STR{"VALID"});
    conv_dw_->set_attr_pads(
      ge::AttrValue::LIST_INT({conv_param_->pad_u_, conv_param_->pad_d_, conv_param_->pad_l_, conv_param_->pad_r_}));
  }
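For reference on the pads attribute being set here: the four-element list is {pad_up, pad_down, pad_left, pad_right}, and with explicit padding the spatial output size of the depthwise convolution follows the usual formula. A small self-contained check (plain C++, independent of the HiAI API):

#include <cstdio>

// Output extent along one spatial axis with explicit padding:
//   out = (in + pad_begin + pad_end - kernel) / stride + 1
int ConvOutDim(int in, int pad_begin, int pad_end, int kernel, int stride) {
  return (in + pad_begin + pad_end - kernel) / stride + 1;
}

int main() {
  // e.g. a 224-wide input, 3x3 kernel, stride 2, pads {0, 1}: (224 + 1 - 3) / 2 + 1 = 112
  std::printf("%d\n", ConvOutDim(224, 0, 1, 3, 2));
  return 0;
}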
@ -61,13 +61,19 @@ int ConvolutionDepthwiseNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *
    return RET_ERROR;
  }

-  ret = InitWeightBiasConst(inputs);
+  ret = InitWeightConst(inputs);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Set weight and bias for convolution depthwise op " << name_ << " failed when running npu";
    return RET_ERROR;
  }
  conv_dw_->set_input_filter(*weight_);

  if (inputs.size() == 3) {
    ret = InitBiasConst(inputs);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Set bias for convolution depthwise op " << name_ << " failed when running npu";
      return RET_ERROR;
    }
    conv_dw_->set_input_bias(*bias_);
  }
  conv_dw_->set_input_x(*npu_inputs[0]);
@ -65,13 +65,19 @@ int ConvolutionNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs
    return RET_ERROR;
  }

-  ret = InitWeightBiasConst(inputs);
+  ret = InitWeightConst(inputs);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Set weight and bias for convolution op " << name_ << " failed when running npu";
    return RET_ERROR;
  }
  conv_->set_input_filter(*weight_);

  if (inputs.size() == 3) {
    ret = InitBiasConst(inputs);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Set bias for convolution op " << name_ << " failed when running npu";
      return RET_ERROR;
    }
    conv_->set_input_bias(*bias_);
  }
  conv_->set_input_x(*npu_inputs[0]);
@ -65,13 +65,19 @@ int DeconvolutionNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inpu
    return RET_ERROR;
  }

-  ret = InitWeightBiasConst(inputs);
+  ret = InitWeightConst(inputs);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Set weight and bias for deconvolution op " << name_ << " failed when running npu";
    return RET_ERROR;
  }
  deconv_->set_input_filter(*weight_);

  if (inputs.size() == 3) {
    ret = InitBiasConst(inputs);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Set bias for deconvolution op " << name_ << " failed when running npu";
      return RET_ERROR;
    }
    deconv_->set_input_bias(*bias_);
  }
  deconv_->set_input_x(*npu_inputs[0]);
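The same rewrite recurs in all three convolution kernels above: weight initialization is now unconditional, while bias initialization runs only when a third input tensor is present. A condensed sketch of the shared control flow, with hypothetical free-function stubs standing in for the ConvolutionBaseNPUKernel methods:

constexpr int RET_OK = 0;
constexpr int RET_ERROR = 1;

// Hypothetical stubs; the real InitWeightConst/InitBiasConst are methods
// on ConvolutionBaseNPUKernel and build hiai::op::Const nodes.
int InitWeightConst() { return RET_OK; }
int InitBiasConst() { return RET_OK; }

int SetConvInputs(int input_count) {
  if (InitWeightConst() != RET_OK) {  // filter input is always required
    return RET_ERROR;
  }
  if (input_count == 3 && InitBiasConst() != RET_OK) {  // bias input is optional
    return RET_ERROR;
  }
  return RET_OK;
}

int main() { return SetConvInputs(3); }  // with bias: exits 0 (RET_OK)

Splitting the old InitWeightBiasConst in two also lets the new fullconnection kernel below reuse InitBiasConst without the convolution-specific weight handling.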
@ -0,0 +1,122 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/runtime/kernel/npu/fullconnection_npu.h"
#include <memory>
#include "src/kernel_registry.h"
#include "src/runtime/agent/npu/npu_converter_utils.h"
using mindspore::kernel::KERNEL_ARCH::kNPU;
using mindspore::lite::KernelRegistrar;
using mindspore::schema::PrimitiveType_FullConnection;

namespace mindspore::kernel {
int FullconnectionNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs,
                                       const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter) {
  return RET_OK;
}

int FullconnectionNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs,
                                          const std::vector<lite::Tensor *> &outputs,
                                          const std::vector<ge::Operator *> &npu_inputs) {
  auto input_shape = inputs[0]->shape();
  reshape_ = new (std::nothrow) hiai::op::Reshape(name_ + "_reshape");
  if (reshape_ == nullptr) {
    MS_LOG(ERROR) << "New reshape operator for fullconnection op " << name_ << " failed.";
    return RET_ERROR;
  }
  reshape_->set_input_x(*npu_inputs[0]);
  int col = 1;
  for (int i = 1; i < input_shape.size(); i++) {
    col *= input_shape[i];
  }
  auto reshape_op = new (std::nothrow) hiai::op::Const(name_ + "_reshape_data");
  vector<int> reshape_data = {input_shape[0], col};
  ge::TensorDesc reshape_tensor_desc(ge::Shape({2}), ge::FORMAT_NCHW, ge::DT_FLOAT);
  ge::TensorPtr reshape_tensor = std::make_shared<hiai::Tensor>(reshape_tensor_desc);
  reshape_tensor->SetData(reinterpret_cast<uint8_t *>(reshape_data.data()), 2 * sizeof(float));
  reshape_op->set_attr_value(reshape_tensor);
  reshape_->set_input_shape(*reshape_op);

  fc_ = new (std::nothrow) hiai::op::MatMul(name_);
  if (fc_ == nullptr) {
    MS_LOG(ERROR) << "New matmul operator for fullconnection op " << name_ << " failed.";
    return RET_ERROR;
  }
  fc_->set_input_x1(*reshape_);

  weight_ = new (std::nothrow) hiai::op::Const(name_ + "_w");
  if (weight_ == nullptr) {
    MS_LOG(ERROR) << "New weight const failed.";
    return RET_ERROR;
  }
  inputs[1]->set_format(schema::Format_NCHW);
  auto weight_tensor = mindspore::lite::ConverterToNPUTensor(inputs[1]);
  weight_->set_attr_value(weight_tensor);
  inputs[1]->set_format(schema::Format_NHWC);
  fc_->set_input_x2(*weight_).set_attr_transpose_x2(true);

  if (fc_param_->has_bias_) {
    biasadd_ = new (std::nothrow) hiai::op::BiasAdd(name_ + "_biasadd");
    if (biasadd_ == nullptr) {
      MS_LOG(ERROR) << "New biasadd operator for fullconnection op " << name_ << " failed.";
      return RET_ERROR;
    }

    auto ret = InitBiasConst(inputs);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Set bias for convolution op " << name_ << " failed when running npu";
      return RET_ERROR;
    }
    biasadd_->set_input_x(*fc_).set_input_bias(*bias_);
  }

  if (fc_param_->act_type_ != ActType_No) {
    auto ret =
      biasadd_ == nullptr ? SetActivation(fc_, fc_param_->act_type_) : SetActivation(biasadd_, fc_param_->act_type_);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "New activation npu operator for op " << name_ << " failed.";
      return RET_ERROR;
    }
  }
  return RET_OK;
}

ge::Operator *mindspore::kernel::FullconnectionNPUKernel::GetNPUOp() {
  if (fc_param_->act_type_ != ActType_No) {
    return act_;
  }
  if (fc_param_->has_bias_) {
    return biasadd_;
  }
  return fc_;
}

FullconnectionNPUKernel::~FullconnectionNPUKernel() {
  if (reshape_ != nullptr) {
    delete reshape_;
    reshape_ = nullptr;
  }
  if (fc_ != nullptr) {
    delete fc_;
    fc_ = nullptr;
  }
  if (biasadd_ != nullptr) {
    delete biasadd_;
    biasadd_ = nullptr;
  }
}
REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_FullConnection, NPUKernelCreator<FullconnectionNPUKernel>)
}  // namespace mindspore::kernel
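Two notes on the new kernel above. First, the Reshape collapses every dimension after the batch into a single column count so that MatMul sees a 2-D [batch, features] matrix; a tiny standalone check of that computation (plain C++, no HiAI dependency):

#include <cassert>
#include <vector>

// Collapse all trailing dimensions, as SetNPUInputs does when it builds
// the {input_shape[0], col} shape for the Reshape operator.
std::vector<int> FlattenToMatrix(const std::vector<int> &shape) {
  int col = 1;
  for (size_t i = 1; i < shape.size(); ++i) {
    col *= shape[i];
  }
  return {shape[0], col};
}

int main() {
  // e.g. an NHWC feature map of [2, 4, 4, 3] flattens to a [2, 48] matrix.
  assert((FlattenToMatrix({2, 4, 4, 3}) == std::vector<int>{2, 48}));
  return 0;
}

Second, one thing worth flagging for review: the shape const is described as ge::DT_FLOAT but filled from a vector<int>. Reshape shape tensors are conventionally int32 (as the reduce kernel below does with DT_INT32), so DT_INT32 looks like the intended type; the byte count happens to be the same either way.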
@ -0,0 +1,47 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_FULLCONNECTION_NPU_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_FULLCONNECTION_NPU_H_
#include <vector>
#include "src/runtime/kernel/npu/convolution_base_npu.h"
#include "include/graph/op/all_ops.h"
#include "nnacl/matmul_parameter.h"
namespace mindspore::kernel {
class FullconnectionNPUKernel : public ConvolutionBaseNPUKernel {
 public:
  FullconnectionNPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                          const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                          const mindspore::lite::PrimitiveC *primitive)
      : ConvolutionBaseNPUKernel(parameter, inputs, outputs, ctx, primitive) {
    fc_param_ = reinterpret_cast<MatMulParameter *>(parameter);
  }
  ~FullconnectionNPUKernel() override;

  int IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
                OpParameter *opParameter) override;
  int SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
                   const std::vector<ge::Operator *> &npu_inputs) override;
  ge::Operator *GetNPUOp() override;

 private:
  hiai::op::Reshape *reshape_ = nullptr;
  hiai::op::MatMul *fc_ = nullptr;
  hiai::op::BiasAdd *biasadd_ = nullptr;
  MatMulParameter *fc_param_;
};
}  // namespace mindspore::kernel
#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_FULLCONNECTION_NPU_H_
@ -0,0 +1,72 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/runtime/kernel/npu/reduce_npu.h"
#include <memory>
#include "src/kernel_registry.h"
#include "include/graph/op/all_ops.h"
#include "src/runtime/agent/npu/npu_converter_utils.h"
using mindspore::kernel::KERNEL_ARCH::kNPU;
using mindspore::lite::KernelRegistrar;
using mindspore::schema::PrimitiveType_Reduce;
using mindspore::schema::ReduceMode_ReduceMean;

namespace mindspore::kernel {
int ReduceNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
                               OpParameter *opParameter) {
  if (reduce_param_->mode_ != ReduceMode_ReduceMean) {
    MS_LOG(ERROR) << "Npu does not support reduce mode " << reduce_param_->mode_ << " for op " << name_;
    return RET_ERROR;
  }
  if (reduce_param_->reduce_to_end_) {
    return RET_ERROR;
  }
  return RET_OK;
}

int ReduceNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
                                  const std::vector<ge::Operator *> &npu_inputs) {
  std::vector<int32_t> axes;
  for (int i = 0; i < reduce_param_->num_axes_; i++) {
    axes.push_back(reduce_param_->axes_[i]);
  }
  auto axes_op = new (std::nothrow) hiai::op::Const(name_ + "_reduce_axes");
  ge::TensorDesc axes_tensor_desc(ge::Shape({reduce_param_->num_axes_}), ge::FORMAT_NCHW, ge::DT_INT32);
  ge::TensorPtr axes_tensor = std::make_shared<hiai::Tensor>(axes_tensor_desc);
  axes_tensor->SetData(reinterpret_cast<uint8_t *>(axes.data()), reduce_param_->num_axes_ * sizeof(int32_t));
  axes_op->set_attr_value(axes_tensor);

  auto reduce_mean_ = new (std::nothrow) hiai::op::ReduceMean(name_);
  if (reduce_mean_ == nullptr) {
    MS_LOG(ERROR) << "New reduce operator for op " << name_ << " failed.";
    return RET_ERROR;
  }
  reduce_mean_->set_input_x(*npu_inputs[0]).set_input_axes(*axes_op).set_attr_keep_dims(reduce_param_->keep_dims_);
  reduce_ = reduce_mean_;
  return RET_OK;
}

ge::Operator *mindspore::kernel::ReduceNPUKernel::GetNPUOp() { return this->reduce_; }

ReduceNPUKernel::~ReduceNPUKernel() {
  if (reduce_ != nullptr) {
    delete reduce_;
    reduce_ = nullptr;
  }
}

REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Reduce, NPUKernelCreator<ReduceNPUKernel>)
}  // namespace mindspore::kernel
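For reference, the ReduceMean wired up above averages the input over the listed axes, optionally keeping each reduced axis as a size-1 dimension. A host-side sketch of those semantics for a 2-D input reduced over axis 1 (plain C++, no HiAI dependency):

#include <cstdio>
#include <vector>

// Mean over axis 1 of a [rows, cols] matrix. With keep_dims the NPU op
// would report shape [rows, 1]; without it, [rows].
std::vector<float> ReduceMeanAxis1(const std::vector<std::vector<float>> &m) {
  std::vector<float> out;
  for (const auto &row : m) {
    float sum = 0.0f;
    for (float v : row) {
      sum += v;
    }
    out.push_back(sum / row.size());
  }
  return out;
}

int main() {
  auto r = ReduceMeanAxis1({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});
  std::printf("%.1f %.1f\n", r[0], r[1]);  // prints 2.0 5.0
  return 0;
}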
@ -0,0 +1,45 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_REDUCE_NPU_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_REDUCE_NPU_H_
#include <vector>
#include "nnacl/reduce_parameter.h"
#include "src/runtime/kernel/npu/npu_kernel.h"
#include "include/graph/op/all_ops.h"
namespace mindspore::kernel {
class ReduceNPUKernel : public NPUKernel {
 public:
  ReduceNPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                  const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                  const mindspore::lite::PrimitiveC *primitive)
      : NPUKernel(parameter, inputs, outputs, ctx, primitive) {
    reduce_param_ = reinterpret_cast<ReduceParameter *>(parameter);
  }
  ~ReduceNPUKernel() override;

  int IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
                OpParameter *opParameter) override;
  int SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
                   const std::vector<ge::Operator *> &npu_inputs) override;
  ge::Operator *GetNPUOp() override;

 private:
  ReduceParameter *reduce_param_;
  hiai::Operator *reduce_ = nullptr;
};
}  // namespace mindspore::kernel
#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_REDUCE_NPU_H_
@ -68,3 +68,4 @@ ml_location_scene_division
ml_tabel_recog
ml_text_division
6c_seg_nomean_20200610
ml_video_edit_person_divison
@ -1,5 +1,28 @@
mobilenet_v1_0.25_128.tflite 2.5
mobilenet_v1_0.25_160.tflite 2.5
mobilenet_v1_0.25_192.tflite 1.5
mobilenet_v1_0.25_224.tflite 2
mobilenet_v1_0.5_128.tflite 2
mobilenet_v1_0.5_160.tflite 2
mobilenet_v1_0.5_192.tflite 2.5
mobilenet_v1_0.5_224.tflite 2
mobilenet_v1_0.75_128.tflite 3
mobilenet_v1_0.75_160.tflite 3
mobilenet_v1_0.75_192.tflite 3.5
mobilenet_v1_0.75_224.tflite 1.5
mobilenet_v1_1.0_128.tflite 6
mobilenet_v1_1.0_160.tflite 2
mobilenet_v1_1.0_192.tflite 6
mobilenet_v1_1.0_224.tflite 2.5
mobilenet_v2_1.0_224.tflite 2.5
squeezenet.tflite 2.5
inception_v3.tflite 1
inception_v4.tflite 0.5
efficientnet_lite0_fp32_2.tflite 1
efficientnet_lite1_fp32_2.tflite 1
efficientnet_lite2_fp32_2.tflite 1
efficientnet_lite3_fp32_2.tflite 1
efficientnet_lite4_fp32_2.tflite 1
6c_seg_nomean_20200610 1.5
ml_video_edit_person_divison 0.5
porseg_tmp.onnx 1 2