[MSLITE] fix bug of scatterNdUpdate

This commit is contained in:
wang_shaocong 2021-09-16 19:51:12 +08:00
parent 89888dbebc
commit 2eb3c425d9
3 changed files with 12 additions and 3 deletions

View File

@ -49,8 +49,6 @@ int ScatterNdUpdateCPUKernel::ReSize() {
auto update = in_tensors_.at(kScatterUpdateIndex);
auto output = out_tensors_.front();
update_ptr_ = reinterpret_cast<float *>(update->MutableData());
MS_ASSERT(update_ptr_ != nullptr);
output_ptr_ = reinterpret_cast<float *>(output->MutableData());
MS_ASSERT(output_ptr_ != nullptr);
@ -151,6 +149,14 @@ int ScatterNdUpdateCPUKernel::Run() {
out_tensor->set_own_data(in_tensor->own_data());
output_ptr_ = reinterpret_cast<float *>(out_tensor->data());
}
auto indices = in_tensors_.at(kScatterIndicesIndex);
if (!indices->IsConst() && ReSize() != RET_OK) {
MS_LOG(ERROR) << "ScatterNdUpdate resize failed.";
return RET_ERROR;
}
auto update = in_tensors_.at(kScatterUpdateIndex);
update_ptr_ = reinterpret_cast<float *>(update->MutableData());
MS_ASSERT(update_ptr_ != nullptr);
auto ret = ParallelLaunch(this->ms_context_, ScatterNdUpdateRun, this, thread_n_num_);
if (ret != RET_OK) {

View File

@ -149,7 +149,9 @@ int CaffeConvBaseParser::ParseGroup(const caffe::ConvolutionParameter &convParam
if (convParam.has_group()) {
return convParam.group();
} else {
return layerType == "ConvolutionDepthwise" ? static_cast<int>(convParam.num_output()) : 1;
return layerType == "ConvolutionDepthwise" || layerType == "DepthwiseConv"
? static_cast<int>(convParam.num_output())
: 1;
}
}

View File

@ -92,5 +92,6 @@ ops::PrimitiveC *CaffeConvolutionParser::Parse(const caffe::LayerParameter &prot
}
CaffeNodeRegistrar g_caffeConvolutionParser("Convolution", new CaffeConvolutionParser());
CaffeNodeRegistrar g_caffeDepthwiseConvolutionParser("DepthwiseConv", new CaffeConvolutionParser());
} // namespace lite
} // namespace mindspore