Merge pull request !4628 from yeyunpeng2020/master_cops_3
This commit is contained in:
mindspore-ci-bot 2020-08-18 16:59:00 +08:00 committed by Gitee
commit b1cfb6d627
574 changed files with 3771 additions and 7812 deletions

View File

@ -1,3 +0,0 @@
file(GLOB_RECURSE C_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cc)
add_library(c_ops_mid OBJECT ${C_OPS_SRC})

View File

@ -1,59 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/addn.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
int AddN::GetN() const { return this->primitive->value.AsAddN()->N; }
void AddN::SetN(int n) { this->primitive->value.AsAddN()->N = n; }
#else
int AddN::GetN() const { return this->primitive->value_as_AddN()->N(); }
void AddN::SetN(int n) {}
#endif
namespace {
constexpr int kLeastInputNum = 2;
}
int AddN::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs.front();
MS_ASSERT(input != nullptr);
auto output = outputs.front();
MS_ASSERT(output != nullptr);
if (inputs.size() < kLeastInputNum) {
MS_LOG(ERROR) << "input size" << inputs.size() << " is error!";
return 1;
}
for (int i = 1; i < inputs.size(); ++i) {
if (inputs.at(i)->shape() != inputs.at(0)->shape()) {
MS_LOG(ERROR) << "AddN inputs shape is not equal!";
return 1;
}
if (inputs.at(i)->data_type() != inputs.at(0)->data_type()) {
MS_LOG(ERROR) << "AddN all input data type should be the same!";
return 1;
}
}
output->SetFormat(input->GetFormat());
output->set_shape(input->shape());
output->set_data_type(input->data_type());
return 0;
}
} // namespace mindspore

View File

@ -1,75 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/argmax.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
int ArgMax::GetAxis() const { return this->primitive->value.AsArgMax()->axis; }
bool ArgMax::GetOutMaxValue() const { return this->primitive->value.AsArgMax()->outMaxValue; }
int ArgMax::GetTopK() const { return this->primitive->value.AsArgMax()->topK; }
bool ArgMax::GetKeepDims() const { return this->primitive->value.AsArgMax()->keepDims; }
int ArgMax::GetAxisType() const { return this->primitive->value.AsArgMax()->axisType; }
void ArgMax::SetAxis(int axis) { this->primitive->value.AsArgMax()->axis = axis; }
void ArgMax::SetOutMaxValue(bool out_max_value) { this->primitive->value.AsArgMax()->outMaxValue = out_max_value; }
void ArgMax::SetTopK(int top_k) { this->primitive->value.AsArgMax()->topK = top_k; }
void ArgMax::SetKeepDims(bool keep_dims) { this->primitive->value.AsArgMax()->keepDims = keep_dims; }
void ArgMax::SetAxisType(int axis_type) { this->primitive->value.AsArgMax()->axisType = axis_type; }
#else
int ArgMax::GetAxis() const { return this->primitive->value_as_ArgMax()->axis(); }
bool ArgMax::GetOutMaxValue() const { return this->primitive->value_as_ArgMax()->outMaxValue(); }
int ArgMax::GetTopK() const { return this->primitive->value_as_ArgMax()->topK(); }
bool ArgMax::GetKeepDims() const { return this->primitive->value_as_ArgMax()->keepDims(); }
int ArgMax::GetAxisType() const { return this->primitive->value_as_ArgMax()->axisType(); }
void ArgMax::SetAxis(int axis) {}
void ArgMax::SetOutMaxValue(bool out_max_value) {}
void ArgMax::SetTopK(int top_k) {}
void ArgMax::SetKeepDims(bool keep_dims) {}
void ArgMax::SetAxisType(int axis_type) {}
#endif
int ArgMax::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "tensor number is error.";
}
std::vector<int> output_shape(input->shape());
auto input_shape_size = input->shape().size();
int axis = GetAxis() < 0 ? GetAxis() + input_shape_size : GetAxis();
if (axis >= input_shape_size || axis < 0) {
MS_LOG(ERROR) << "Invalid axis " << GetAxis() << ", input shape size: " << input_shape_size;
return 1;
}
if (GetTopK() == 1 && !GetKeepDims()) {
output_shape.erase(output_shape.begin() + axis);
} else {
output_shape[axis] = GetTopK();
}
output->SetFormat(input->GetFormat());
output->set_shape(output_shape);
output->set_data_type(input->data_type());
return 0;
}
} // namespace mindspore

View File

@ -1,74 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/argmin.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
int ArgMin::GetAxis() const { return this->primitive->value.AsArgMin()->axis; }
bool ArgMin::GetOutMaxValue() const { return this->primitive->value.AsArgMin()->outMaxValue; }
int ArgMin::GetTopK() const { return this->primitive->value.AsArgMin()->topK; }
bool ArgMin::GetKeepDims() const { return this->primitive->value.AsArgMin()->keepDims; }
int ArgMin::GetAxisType() const { return this->primitive->value.AsArgMin()->axisType; }
void ArgMin::SetAxis(int axis) { this->primitive->value.AsArgMin()->axis = axis; }
void ArgMin::SetOutMaxValue(bool out_max_value) { this->primitive->value.AsArgMin()->outMaxValue = out_max_value; }
void ArgMin::SetTopK(int top_k) { this->primitive->value.AsArgMin()->topK = top_k; }
void ArgMin::SetKeepDims(bool keep_dims) { this->primitive->value.AsArgMin()->keepDims = keep_dims; }
void ArgMin::SetAxisType(int axis_type) { this->primitive->value.AsArgMin()->axisType = axis_type; }
#else
int ArgMin::GetAxis() const { return this->primitive->value_as_ArgMin()->axis(); }
bool ArgMin::GetOutMaxValue() const { return this->primitive->value_as_ArgMin()->outMaxValue(); }
int ArgMin::GetTopK() const { return this->primitive->value_as_ArgMin()->topK(); }
bool ArgMin::GetKeepDims() const { return this->primitive->value_as_ArgMin()->keepDims(); }
int ArgMin::GetAxisType() const { return this->primitive->value_as_ArgMin()->axisType(); }
void ArgMin::SetAxis(int axis) {}
void ArgMin::SetOutMaxValue(bool out_max_value) {}
void ArgMin::SetTopK(int top_k) {}
void ArgMin::SetKeepDims(bool keep_dims) {}
void ArgMin::SetAxisType(int axis_type) {}
#endif
int ArgMin::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "tensor number is error.";
}
auto input_shape_size = input->shape().size();
int axis = GetAxis() < 0 ? GetAxis() + input_shape_size : GetAxis();
if (axis >= input_shape_size || axis < 0) {
MS_LOG(ERROR) << "Invalid axis " << GetAxis() << ", input shape size: " << input_shape_size;
return 1;
}
std::vector<int> output_shape(input->shape());
if (GetTopK() == 1 && !GetKeepDims()) {
output_shape.erase(output_shape.begin() + axis);
} else {
output_shape[axis] = GetTopK();
}
output->SetFormat(input->GetFormat());
output->set_shape(output_shape);
output->set_data_type(input->data_type());
return 0;
}
} // namespace mindspore

View File

@ -1,99 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/arithmetic.h"
namespace mindspore {
int Arithmetic::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
if (inputs_.size() != kDoubleNum) {
MS_LOG(ERROR) << "The number of input must be " << kDoubleNum;
return 1;
}
if (outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "The number of output must be " << kSingleNum;
return 1;
}
auto input0 = inputs_[0];
MS_ASSERT(input0 != nullptr);
auto input1 = inputs_[1];
MS_ASSERT(input1 != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
auto input_shape0 = input0->shape();
auto input_shape1 = input1->shape();
auto format = input0->GetFormat();
in_shape0_.resize(5);
in_shape1_.resize(5);
out_shape_.resize(5);
ndim_ = input_shape0.size();
if (input_shape0.size() < input_shape1.size()) {
ndim_ = input_shape1.size();
auto fill_dim_num = input_shape1.size() - input_shape0.size();
int j = 0;
for (int i = 0; i < input_shape1.size(); i++) {
if (i < fill_dim_num) {
in_shape0_[i] = 1;
} else {
in_shape0_[i] = input_shape0[j++];
}
in_shape1_[i] = input_shape1[i];
}
format = input0->GetFormat();
} else if (input_shape0.size() > input_shape1.size()) {
ndim_ = input_shape0.size();
auto fill_dim_num = input_shape0.size() - input_shape1.size();
int j = 0;
for (int i = 0; i < input_shape0.size(); i++) {
if (i < fill_dim_num) {
in_shape1_[i] = 1;
} else {
in_shape1_[i] = input_shape1[j++];
}
in_shape0_[i] = input_shape0[i];
}
} else {
for (int i = 0; i < input_shape0.size(); i++) {
in_shape1_[i] = input_shape1[i];
in_shape0_[i] = input_shape0[i];
}
}
std::vector<int> output_shape;
for (size_t i = 0; i < ndim_; i++) {
if (in_shape0_[i] != in_shape1_[i]) {
if (in_shape0_[i] == 1) {
out_shape_[i] = in_shape1_[i];
} else if (in_shape1_[i] == 1) {
out_shape_[i] = in_shape0_[i];
} else {
MS_LOG(ERROR) << "shapes of input tensors can not be broadCasted";
return -1;
}
broadcasting_ = true;
} else {
out_shape_[i] = in_shape0_[i];
}
output_shape.push_back(out_shape_[i]);
}
output->SetFormat(format);
output->set_shape(output_shape);
output->set_data_type(input0->data_type());
return 0;
}
} // namespace mindspore

View File

@ -1,34 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/arithmetic_self.h"
namespace mindspore {
int ArithmeticSelf::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->SetFormat(input->GetFormat());
output->set_shape(input->shape());
output->set_data_type(input->data_type());
return 0;
}
} // namespace mindspore

View File

@ -1,114 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/batch_to_space.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
std::vector<int> BatchToSpace::GetBlockShape() const { return this->primitive->value.AsBatchToSpace()->blockShape; }
std::vector<int> BatchToSpace::GetCrops() const { return this->primitive->value.AsBatchToSpace()->crops; }
void BatchToSpace::SetBlockShape(const std::vector<int> &block_shape) {
this->primitive->value.AsBatchToSpace()->blockShape = block_shape;
}
void BatchToSpace::SetCrops(const std::vector<int> &crops) { this->primitive->value.AsBatchToSpace()->crops = crops; }
#else
std::vector<int> BatchToSpace::GetBlockShape() const {
auto fb_vector = this->primitive->value_as_BatchToSpace()->blockShape();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
std::vector<int> BatchToSpace::GetCrops() const {
auto fb_vector = this->primitive->value_as_BatchToSpace()->crops();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
void BatchToSpace::SetBlockShape(const std::vector<int> &block_shape) {}
void BatchToSpace::SetCrops(const std::vector<int> &crops) {}
#endif
namespace {
constexpr int kBatchToSpaceOutputNum = 1;
constexpr int kBatchToSpaceInputNum = 1;
constexpr int kBlockShapeSize = 2;
constexpr int kCropsSize = 4;
} // namespace
int BatchToSpace::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
MS_ASSERT(this->primitive != nullptr);
if (outputs.size() != kBatchToSpaceOutputNum || inputs.size() != kBatchToSpaceInputNum) {
MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
return 1;
}
auto input = inputs.at(0);
if (input->GetFormat() != schema::Format_NHWC) {
MS_LOG(ERROR) << "batch_to_space only support NHWC now!";
return 1;
}
auto input_shape = input->shape();
if (input_shape.size() != kDimension_4d) {
MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d;
return 1;
}
auto block_shape = GetBlockShape();
if (block_shape.size() != kBlockShapeSize) {
MS_LOG(ERROR) << "Block shape size should be " << kBlockShapeSize;
return 1;
}
auto crops = GetCrops();
if (crops.size() != kCropsSize) {
MS_LOG(ERROR) << "Crops size should be " << kCropsSize;
return 1;
}
size_t mul_block_shape = 1;
for (size_t i = 0; i < kBlockShapeSize; ++i) {
if (block_shape[i] <= 0) {
MS_LOG(ERROR) << "Input block_shape should > 0!";
return 1;
}
if (input_shape[NHWC_N] % block_shape[i]) {
MS_LOG(ERROR) << "Dimension n " << input_shape[NHWC_N] << " can not divide block_shape[" << i << "] "
<< block_shape[i];
return 1;
}
mul_block_shape *= block_shape[i];
}
if (input_shape[NHWC_N] < mul_block_shape) {
MS_LOG(ERROR) << "Dimension n " << input_shape[NHWC_N] << " < product of block shape!";
return 1;
}
for (size_t i = 0; i < kCropsSize; ++i) {
if (crops[i] < 0) {
MS_LOG(ERROR) << "Input crops should >= 0";
return 1;
}
}
std::vector<int32_t> output_shape(input_shape.size());
output_shape[NHWC_N] = input_shape[NHWC_N] / mul_block_shape;
output_shape[NHWC_H] = input_shape[NHWC_H] * block_shape[0] - crops[0] - crops[1];
output_shape[NHWC_W] = input_shape[NHWC_W] * block_shape[1] - crops[2] - crops[3];
output_shape[NHWC_C] = input_shape[NHWC_C];
outputs[0]->SetFormat(input->GetFormat());
outputs[0]->set_shape(output_shape);
outputs[0]->set_data_type(input->data_type());
return 0;
}
} // namespace mindspore

View File

@ -1,79 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/broadcast_to.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
std::vector<int> BroadcastTo::GetDstShape() const { return this->primitive->value.AsBroadcastTo()->dst_shape; }
void BroadcastTo::SetDstShape(const std::vector<int> &dst_shape) {
this->primitive->value.AsBroadcastTo()->dst_shape = dst_shape;
}
#else
std::vector<int> BroadcastTo::GetDstShape() const {
auto fb_vector = this->primitive->value_as_BroadcastTo()->dst_shape();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
void BroadcastTo::SetDstShape(const std::vector<int> &dst_shape) {}
#endif
namespace {
constexpr int kBroadcastToInputNum = 1;
constexpr int kBroadcastToOutputNum = 1;
} // namespace
int BroadcastTo::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
MS_ASSERT(this->primitive != nullptr);
if (inputs.size() != kBroadcastToInputNum || outputs.size() != kBroadcastToOutputNum) {
MS_LOG(ERROR) << "input size:" << inputs.size() << ", output size:" << outputs.size();
return 1;
}
auto input = inputs.at(0);
std::vector<int32_t> dst_shape(this->primitive->value_as_BroadcastTo()->dst_shape()->begin(),
this->primitive->value_as_BroadcastTo()->dst_shape()->end());
auto input_shape = input->shape();
std::vector<int> shape(dst_shape.size());
int input_shape_index = input_shape.size() - 1;
if (input_shape.size() > dst_shape.size()) {
MS_LOG(ERROR) << "input shape size " << input_shape.size() << " should <= broadcast to shape size "
<< dst_shape.size() << "!";
return 1;
}
for (int i = dst_shape.size() - 1; i >= 0; --i) {
if (dst_shape[i] < 0) {
MS_LOG(ERROR) << "shape[" << i << "] = " << dst_shape[i] << " ] should be > 0!";
return 1;
}
if (input_shape_index >= 0) {
auto dim = input_shape[input_shape_index];
if (dim != dst_shape[i] && dim != 1) {
MS_LOG(ERROR) << "Invalid broadcast shape!";
return 1;
}
}
shape[i] = dst_shape[i];
--input_shape_index;
}
outputs[0]->SetFormat(input->GetFormat());
outputs[0]->set_shape(shape);
outputs[0]->set_data_type(input->data_type());
return 0;
}
} // namespace mindspore

View File

@ -1,64 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/cast.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
int Cast::GetSrcT() const { return this->primitive->value.AsCast()->srcT; }
int Cast::GetDstT() const { return this->primitive->value.AsCast()->dstT; }
void Cast::SetSrcT(int src_t) { this->primitive->value.AsCast()->srcT = src_t; }
void Cast::SetDstT(int dst_t) { this->primitive->value.AsCast()->dstT = dst_t; }
#else
int Cast::GetSrcT() const { return this->primitive->value_as_Cast()->srcT(); }
int Cast::GetDstT() const { return this->primitive->value_as_Cast()->dstT(); }
void Cast::SetSrcT(int src_t) {}
void Cast::SetDstT(int dst_t) {}
#endif
int Cast::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "tensor number is error.";
return 1;
}
MS_ASSERT(cast_prim != nullptr);
if (input->data_type() != GetSrcT()) {
MS_LOG(ERROR) << "input dataType is error";
return 1;
}
if (kSupportDataType.find(input->data_type()) == kSupportDataType.end()) {
MS_LOG(ERROR) << "Unsupported input data type " << input->data_type();
return 1;
}
if (GetDstT() != kNumberTypeFloat && GetDstT() != kNumberTypeFloat32) {
MS_LOG(ERROR) << "Invalid output datatype " << GetDstT();
return 1;
}
output->SetFormat(input->GetFormat());
output->set_shape(input->shape());
output->set_data_type(input->data_type());
return 0;
}
} // namespace mindspore

View File

@ -1,93 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/concat.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
int Concat::GetAxis() const { return this->primitive->value.AsConcat()->axis; }
int Concat::GetN() const { return this->primitive->value.AsConcat()->n; }
void Concat::SetAxis(int axis) { this->primitive->value.AsConcat()->axis = axis; }
void Concat::SetN(int n) { this->primitive->value.AsConcat()->n = n; }
#else
int Concat::GetAxis() const { return this->primitive->value_as_Concat()->axis(); }
int Concat::GetN() const { return this->primitive->value_as_Concat()->n(); }
void Concat::SetAxis(int axis) {}
void Concat::SetN(int n) {}
#endif
namespace {
constexpr int kConcatOutputNum = 1;
}
int Concat::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
if (this->primitive == nullptr) {
MS_LOG(ERROR) << "primitive is nullptr!";
return 1;
}
auto input0 = inputs_.front();
auto output = outputs_.front();
if (outputs_.size() != kConcatOutputNum) {
MS_LOG(ERROR) << "output size is error";
return 1;
}
MS_ASSERT(concat_prim != nullptr);
auto input0_shape = inputs_.at(0)->shape();
int axis = GetAxis() < 0 ? GetAxis() + input0_shape.size() : GetAxis();
if (axis < 0 || axis >= input0_shape.size()) {
MS_LOG(ERROR) << "Invalid axis: " << axis;
return 1;
}
auto input0_shape_without_axis = input0_shape;
input0_shape_without_axis.erase(input0_shape_without_axis.begin() + axis);
auto input0_data_type = inputs_.at(0)->data_type();
schema::Format input0_format = inputs_[0]->GetFormat();
int output_axis_dim = input0_shape.at(axis);
for (size_t i = 1; i < inputs_.size(); ++i) {
if (inputs_.at(i)->data_type() != input0_data_type) {
MS_LOG(ERROR) << "All inputs should have the same data type!";
return 1;
}
if (inputs_.at(i)->GetFormat() != input0_format) {
MS_LOG(ERROR) << "All input format should be the same!";
return 1;
}
auto shape_tmp = inputs_.at(i)->shape();
if (shape_tmp.size() != input0_shape.size()) {
MS_LOG(ERROR) << "All inputs should have the same dim num!";
return 1;
}
auto axis_tmp = shape_tmp[axis];
shape_tmp.erase(shape_tmp.begin() + axis);
if (input0_shape_without_axis != shape_tmp) {
MS_LOG(ERROR) << "Inputs should have the same dim except axis!";
return 1;
}
output_axis_dim += axis_tmp;
}
auto output_shape = input0_shape;
output_shape[axis] = output_axis_dim;
outputs_[0]->set_shape(output_shape);
output->set_data_type(input0->data_type());
output->SetFormat(input0->GetFormat());
return 0;
}
} // namespace mindspore

View File

@ -1,55 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/crop.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
long Crop::GetAxis() const { return this->primitive->value.AsCrop()->axis; }
std::vector<long> Crop::GetOffsets() const { return this->primitive->value.AsCrop()->offsets; }
void Crop::SetAxis(long axis) { this->primitive->value.AsCrop()->axis = axis; }
void Crop::SetOffsets(const std::vector<long> &offsets) { this->primitive->value.AsCrop()->offsets = offsets; }
#else
long Crop::GetAxis() const { return this->primitive->value_as_Crop()->axis(); }
std::vector<long> Crop::GetOffsets() const {
auto fb_vector = this->primitive->value_as_Crop()->offsets();
return std::vector<long>(fb_vector->begin(), fb_vector->end());
}
void Crop::SetAxis(long axis) {}
void Crop::SetOffsets(const std::vector<long> &offsets) {}
#endif
namespace {
constexpr int kCropOutputNum = 1;
constexpr int kCropInputNum = 2;
} // namespace
int Crop::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
MS_ASSERT(this->primitive != nullptr);
if (outputs.size() != kCropOutputNum || inputs.size() != kCropInputNum) {
MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
return 1;
}
outputs[0]->set_shape(inputs[1]->shape());
outputs[0]->SetFormat(inputs[0]->GetFormat());
outputs[0]->set_data_type(inputs[0]->data_type());
return 0;
}
} // namespace mindspore

View File

@ -1,75 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/depth_to_space.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
int DepthToSpace::GetBlockSize() const { return this->primitive->value.AsDepthToSpace()->blockSize; }
int DepthToSpace::GetFormat() const { return this->primitive->value.AsDepthToSpace()->format; }
void DepthToSpace::SetBlockSize(int block_size) { this->primitive->value.AsDepthToSpace()->blockSize = block_size; }
void DepthToSpace::SetFormat(int format) { this->primitive->value.AsDepthToSpace()->format = format; }
#else
int DepthToSpace::GetBlockSize() const { return this->primitive->value_as_DepthToSpace()->blockSize(); }
int DepthToSpace::GetFormat() const { return this->primitive->value_as_DepthToSpace()->format(); }
void DepthToSpace::SetBlockSize(int block_size) {}
void DepthToSpace::SetFormat(int format) {}
#endif
namespace {
constexpr int kDepthToSpaceOutputNum = 1;
constexpr int kDepthToSpaceInputNum = 1;
} // namespace
int DepthToSpace::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
MS_ASSERT(this->primitive != nullptr);
if (outputs.size() != kDepthToSpaceOutputNum || inputs.size() != kDepthToSpaceInputNum) {
MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
return 1;
}
auto input = inputs.at(0);
if (input->GetFormat() != schema::Format_NHWC) {
MS_LOG(ERROR) << "depth_to_space only support NHWC now!";
return 1;
}
auto input_shape = input->shape();
if (input_shape.size() != kDimension_4d) {
MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d;
return 1;
}
int32_t block_size = GetBlockSize();
if (input_shape[NHWC_C] % (block_size * block_size) != 0 || input_shape[NHWC_C] == 0) {
MS_LOG(ERROR) << "input dimension c size " << input_shape[NHWC_C] << " should be mulitple of block_size("
<< block_size << ") * block_size)!";
return 1;
}
std::vector<int32_t> output_shape(input_shape.size());
output_shape[NHWC_N] = input_shape[NHWC_N];
output_shape[NHWC_H] = input_shape[NHWC_H] * block_size;
output_shape[NHWC_W] = input_shape[NHWC_W] * block_size;
output_shape[NHWC_C] = input_shape[NHWC_C] / (block_size * block_size);
outputs[0]->set_shape(output_shape);
outputs[0]->set_data_type(input->data_type());
outputs[0]->SetFormat(input->GetFormat());
return 0;
}
} // namespace mindspore

View File

@ -1,72 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/embedding_lookup.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
float EmbeddingLookup::GetMaxNorm() const { return this->primitive->value.AsEmbeddingLookup()->maxNorm; }
void EmbeddingLookup::SetMaxNorm(float max_norm) { this->primitive->value.AsEmbeddingLookup()->maxNorm = max_norm; }
#else
float EmbeddingLookup::GetMaxNorm() const { return this->primitive->value_as_EmbeddingLookup()->maxNorm(); }
void EmbeddingLookup::SetMaxNorm(float max_norm) {}
#endif
int EmbeddingLookup::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
if (inputs_.size() < kDoubleNum) {
MS_LOG(ERROR) << "Embedding Lookup should have at least two inputs";
return 1;
}
if (outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "Embedding Lookup should have one outputs";
return 1;
}
auto params_ = inputs_.front();
MS_ASSERT(params_ != nullptr);
auto ids = inputs_.back();
MS_ASSERT(ids != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
auto embedding_shape = params_->shape();
embedding_shape.erase(embedding_shape.begin());
std::vector<int> output_shape(ids->shape());
for (size_t i = 0; i < embedding_shape.size(); ++i) {
output_shape.push_back(embedding_shape.at(i));
}
for (int i = 1; i < inputs_.size() - 1; ++i) {
auto embedding_shape_t = inputs_.at(i)->shape();
embedding_shape_t.erase(embedding_shape_t.begin());
if (embedding_shape_t != embedding_shape) {
MS_LOG(ERROR) << "The embedded layers should have the same shape";
return 1;
}
}
output->set_shape(output_shape);
output->set_data_type(params_->data_type());
return 0;
}
} // namespace mindspore

View File

@ -1,60 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/expand_dims.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
int ExpandDims::GetDim() const { return this->primitive->value.AsExpandDims()->dim; }
void ExpandDims::SetDim(int dim) { this->primitive->value.AsExpandDims()->dim = dim; }
#else
int ExpandDims::GetDim() const { return this->primitive->value_as_ExpandDims()->dim(); }
void ExpandDims::SetDim(int dim) {}
#endif
int ExpandDims::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
if (inputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "input size is invalid";
}
if (outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "output size is invalid";
}
int dim = GetDim();
if (dim < 0) {
dim += input->shape().size() + 1;
}
if (dim > input->shape().size()) {
MS_LOG(ERROR) << "attribute dim out of range";
return 1;
}
auto out_shape = input->shape();
out_shape.insert(out_shape.begin() + dim, 1, 1);
output->set_shape(out_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return 0;
}
} // namespace mindspore

View File

@ -1,56 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/fill.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
std::vector<int> Fill::GetDims() const { return this->primitive->value.AsFill()->dims; }
void Fill::SetDims(const std::vector<int> &dims) { this->primitive->value.AsFill()->dims = dims; }
#else
std::vector<int> Fill::GetDims() const {
auto fb_vector = this->primitive->value_as_Fill()->dims();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
void Fill::SetDims(const std::vector<int> &dims) {}
#endif
int Fill::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front();
auto output = outputs_.front();
if (input == nullptr || output == nullptr) {
MS_LOG(ERROR) << "Fill input or output is null!";
return 1;
}
if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size();
return 1;
}
std::vector<int> output_shape;
(void)output_shape.insert(output_shape.begin(), GetDims().begin(), GetDims().end());
output->set_shape(output_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return 0;
}
} // namespace mindspore

View File

@ -1,47 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/flatten.h"
namespace mindspore {
int Flatten::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front();
auto output = outputs_.front();
if (input == nullptr || output == nullptr) {
MS_LOG(ERROR) << "Flatten input or output is null!";
return 1;
}
if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size();
return 1;
}
auto input_shape = input->shape();
std::vector<int> output_shape(2);
output_shape[0] = input_shape[0];
output_shape[1] = 1;
for (int i = 1; i < input_shape.size(); i++) {
output_shape[1] *= input_shape[i];
}
output->set_shape(output_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return 0;
}
} // namespace mindspore

View File

@ -1,87 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/gather.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
int Gather::GetAxis() const { return this->primitive->value.AsGather()->axis; }
int Gather::GetBatchDims() const { return this->primitive->value.AsGather()->batchDims; }
void Gather::SetAxis(int axis) { this->primitive->value.AsGather()->axis = axis; }
void Gather::SetBatchDims(int batch_dims) { this->primitive->value.AsGather()->batchDims = batch_dims; }
#else
int Gather::GetAxis() const { return this->primitive->value_as_Gather()->axis(); }
int Gather::GetBatchDims() const { return this->primitive->value_as_Gather()->batchDims(); }
void Gather::SetAxis(int axis) {}
void Gather::SetBatchDims(int batch_dims) {}
#endif
int Gather::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
if (inputs_.size() != kDoubleNum) {
MS_LOG(ERROR) << "Gather should have two inputs";
return 1;
}
if (outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "Gather should have one outputs";
return 1;
}
auto input = inputs_.at(0);
MS_ASSERT(input != nullptr);
auto indices = inputs_.at(1);
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(input != nullptr);
int axis = GetAxis();
int batch_dims = GetBatchDims();
if (axis < 0) {
axis += input->shape().size();
}
auto indices_shape = indices->shape();
int indices_rank = indices_shape.size();
if (indices_rank < batch_dims + 1) {
MS_LOG(ERROR) << "input[1]'s rank is less than batchDim + 1";
return 1;
}
if (batch_dims != 0) {
MS_LOG(ERROR) << "batchDims " << batch_dims << " != 0, which is not support";
return 1;
}
auto in_shape = input->shape();
int in_rank = in_shape.size();
if (in_rank < axis + 1) {
MS_LOG(ERROR) << "input[0]'s rank is less than axis + 1";
return 1;
}
std::vector<int> out_shape{in_shape};
out_shape.erase(out_shape.begin() + axis);
for (size_t i = 0; i < indices_rank; i++) {
out_shape.insert(out_shape.begin() + axis, indices_shape[i]);
}
output->set_shape(out_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return 0;
}
} // namespace mindspore

View File

@ -1,74 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/gather_nd.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
int GatherNd::GetBatchDims() const { return this->primitive->value.AsGatherNd()->batchDims; }
void GatherNd::SetBatchDims(int batch_dims) { this->primitive->value.AsGatherNd()->batchDims = batch_dims; }
#else
int GatherNd::GetBatchDims() const { return this->primitive->value_as_GatherNd()->batchDims(); }
void GatherNd::SetBatchDims(int batch_dims) {}
#endif
int GatherNd::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
if (inputs_.size() != kDoubleNum) {
MS_LOG(ERROR) << "GatherNd should have two inputs";
return 1;
}
if (outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "GatherNd should have one outputs";
return 1;
}
auto input = inputs_.at(0);
MS_ASSERT(input != nullptr);
auto indices = inputs_.at(1);
MS_ASSERT(indices != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
auto in_shape = input->shape();
int in_rank = in_shape.size();
auto indices_shape = indices->shape();
int indices_rank = indices_shape.size();
if (indices_shape[indices_rank - 1] > in_rank) {
MS_LOG(ERROR) << "Input of indices data is error!";
return 1;
}
std::vector<int> out_shape;
int i = 0;
for (i = 0; i < indices_rank - 1; ++i) {
out_shape.emplace_back(indices_shape[i]);
}
for (i = indices_shape[indices_rank - 1]; i < in_rank; ++i) {
out_shape.emplace_back(in_shape[i]);
}
output->set_shape(out_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return 0;
}
} // namespace mindspore

View File

@ -1,77 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/lstm.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
bool Lstm::GetBidirection() const { return this->primitive->value.AsLstm()->bidirection; }
void Lstm::SetBidirection(bool bidirection) { this->primitive->value.AsLstm()->bidirection = bidirection; }
#else
bool Lstm::GetBidirection() const { return this->primitive->value_as_Lstm()->bidirection(); }
void Lstm::SetBidirection(bool bidirection) {}
#endif
const int kLstmInputNum = 6;
const int kLstmOutputNum = 3;
int Lstm::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
if (inputs_.size() != kLstmInputNum || outputs_.size() != kLstmOutputNum) {
MS_LOG(ERROR) << "OpLstm inputs or outputs size error.";
return 1;
}
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto weight_i = inputs_.front();
MS_ASSERT(input0 != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
std::vector<int> in_shape = input->shape();
std::vector<int> w_shape = weight_i->shape(); // layer, hidden_size * 4, input_size
if (in_shape.size() != 3 || w_shape.size() != 3) {
MS_LOG(ERROR) << "OpLstm input dims should be 3.";
return 1;
}
int hidden_size = w_shape[1] / 4;
// set output
std::vector<int> out_shape(in_shape);
out_shape[2] = hidden_size;
if (GetBidirection()) {
out_shape.insert(out_shape.begin() + 1, 2);
}
output->set_shape(out_shape);
// set hidden state, cell state
std::vector<int> state_shape(in_shape);
state_shape[0] = GetBidirection() ? 2 : 1;
state_shape[2] = hidden_size;
outputs_[1]->set_shape(state_shape);
outputs_[2]->set_shape(state_shape);
for (int i = 0; i < kLstmOutputNum; i++) {
outputs_[i]->set_data_type(input->data_type());
outputs_[i]->SetFormat(input->GetFormat());
}
return 0;
}
} // namespace mindspore

View File

@ -1,77 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/matmul.h"
#include <utility>
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
bool MatMul::GetTransposeA() const { return this->primitive->value.AsMatMul()->transposeA; }
bool MatMul::GetTransposeB() const { return this->primitive->value.AsMatMul()->transposeB; }
void MatMul::SetTransposeA(bool transpose_a) { this->primitive->value.AsMatMul()->transposeA = transpose_a; }
void MatMul::SetTransposeB(bool transpose_b) { this->primitive->value.AsMatMul()->transposeB = transpose_b; }
#else
bool MatMul::GetTransposeA() const { return this->primitive->value_as_MatMul()->transposeA(); }
bool MatMul::GetTransposeB() const { return this->primitive->value_as_MatMul()->transposeB(); }
void MatMul::SetTransposeA(bool transpose_a) {}
void MatMul::SetTransposeB(bool transpose_b) {}
#endif
int MatMul::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
if (inputs_.size() != kDoubleNum) {
MS_LOG(ERROR) << "OpMatMul inputs size: " << inputs_.size();
return 1;
}
auto input0 = inputs_.front();
MS_ASSERT(input0 != nullptr);
auto input1 = inputs_.at(1);
MS_ASSERT(input1 != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
std::vector<int> a_shape = input0->shape();
std::vector<int> b_shape = input1->shape();
if (a_shape.size() < 2 || b_shape.size() < 2) {
MS_LOG(ERROR) << "inputs shape is invalid";
return 1;
}
for (int i = 0; i < a_shape.size() - 2; ++i) {
if (a_shape[i] != b_shape[i]) {
MS_LOG(ERROR) << "Op MatMul's dimensions must be equal";
return 1;
}
}
if (GetTransposeA()) {
std::swap(a_shape[a_shape.size() - 1], a_shape[a_shape.size() - 2]);
}
if (GetTransposeB()) {
std::swap(b_shape[b_shape.size() - 1], b_shape[b_shape.size() - 2]);
}
std::vector<int> c_shape(a_shape);
c_shape[c_shape.size() - 1] = b_shape[b_shape.size() - 1];
output->set_shape(c_shape);
output->set_data_type(input0->data_type());
output->SetFormat(input0->GetFormat());
return 0;
}
} // namespace mindspore

View File

@ -1,94 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/mean.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
std::vector<int> Mean::GetAxis() const { return this->primitive->value.AsMean()->axis; }
bool Mean::GetKeepDims() const { return this->primitive->value.AsMean()->keepDims; }
void Mean::SetAxis(const std::vector<int> &axis) { this->primitive->value.AsMean()->axis = axis; }
void Mean::SetKeepDims(bool keep_dims) { this->primitive->value.AsMean()->keepDims = keep_dims; }
#else
std::vector<int> Mean::GetAxis() const {
auto fb_vector = this->primitive->value_as_Mean()->axis();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
bool Mean::GetKeepDims() const { return this->primitive->value_as_Mean()->keepDims(); }
void Mean::SetAxis(const std::vector<int> &axis) {}
void Mean::SetKeepDims(bool keep_dims) {}
#endif
namespace {
constexpr size_t kInputSize = 1;
constexpr size_t kOutputSize = 1;
} // namespace
int Mean::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
if (inputs_.size() != kInputSize || outputs_.size() != kOutputSize) {
return 1;
}
auto input = inputs_.front();
auto output = outputs_.front();
if (input == nullptr || output == nullptr) {
return 1;
}
if (this->primitive == nullptr) {
return 1;
}
bool keep_dims = static_cast<bool>(GetKeepDims());
std::vector<int> in_shape = input->shape();
std::vector<int> out_shape;
const auto &axes = GetAxis();
auto num_axes = axes.size();
// reduce on all axes
if (num_axes == 0) {
if (keep_dims) {
for (auto i = 0; i < in_shape.size(); i++) {
out_shape.push_back(1);
}
}
output->set_shape(out_shape);
output->set_data_type(input->data_type());
return 0;
}
// reduce on selected axes
for (size_t i = 0; i < in_shape.size(); i++) {
bool reduce_axis = false;
for (int idx = 0; idx < num_axes; ++idx) {
if (static_cast<size_t>(axes[idx]) == i) {
reduce_axis = true;
break;
}
}
if (reduce_axis) {
if (keep_dims) {
out_shape.push_back(1);
}
} else {
out_shape.push_back(in_shape[i]);
}
}
output->set_shape(out_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return 0;
}
} // namespace mindspore

View File

@ -1,41 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/nchw2nhwc.h"
namespace mindspore {
int Nchw2Nhwc::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
std::vector<int> nchw_shape = input->shape();
if (nchw_shape.size() != 4) {
output->set_shape(nchw_shape);
} else {
std::vector<int> nhwc_shape{nchw_shape};
nhwc_shape[NHWC_N] = nchw_shape[NCHW_N];
nhwc_shape[NHWC_H] = nchw_shape[NCHW_H];
nhwc_shape[NHWC_W] = nchw_shape[NCHW_W];
nhwc_shape[NHWC_C] = nchw_shape[NCHW_C];
output->set_shape(nhwc_shape);
}
output->SetFormat(schema::Format_NHWC);
output->set_data_type(input->data_type());
return 0;
}
} // namespace mindspore

View File

@ -1,41 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/nhwc2nchw.h"
namespace mindspore {
int Nhwc2Nchw::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
std::vector<int> nhwc_shape = input->shape();
if (nhwc_shape.size() != 4) {
output->set_shape(nhwc_shape);
} else {
std::vector<int> nchw_shape{nhwc_shape};
nchw_shape[NCHW_N] = nhwc_shape[NHWC_N];
nchw_shape[NCHW_C] = nhwc_shape[NHWC_C];
nchw_shape[NCHW_H] = nhwc_shape[NHWC_H];
nchw_shape[NCHW_W] = nhwc_shape[NHWC_W];
output->set_shape(nchw_shape);
}
output->SetFormat(schema::Format_NCHW);
output->set_data_type(input->data_type());
return 0;
}
} // namespace mindspore

View File

@ -1,79 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/one_hot.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
int OneHot::GetAxis() const { return this->primitive->value.AsOneHot()->axis; }
void OneHot::SetAxis(int axis) { this->primitive->value.AsOneHot()->axis = axis; }
#else
int OneHot::GetAxis() const { return this->primitive->value_as_OneHot()->axis(); }
void OneHot::SetAxis(int axis) {}
#endif
namespace {
constexpr size_t kOneHotInputNum = 4;
}
int OneHot::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
if (this->primitive == nullptr) {
return 1;
}
int axis = GetAxis();
// indices, depth, on_value, off_value
if (inputs.size() != kOneHotInputNum) {
MS_LOG(ERROR) << "OneHot got inputs num " << inputs.size() << ", should be " << kOneHotInputNum;
return 1;
}
auto depth_tensor = inputs.at(1);
if (depth_tensor == nullptr) {
return 1;
}
const int *depth = static_cast<int *>(depth_tensor->Data());
auto input = inputs.front();
if (input == nullptr) {
return 1;
}
const auto input_shape = input->shape();
int input_rank = static_cast<int>(input_shape.size());
if (axis < 0) {
axis += input_rank + 1;
}
std::vector<int> output_shape(input_shape);
output_shape.insert(output_shape.cbegin() + axis, *depth);
auto output = outputs.front();
if (output == nullptr) {
return 1;
}
output->set_shape(output_shape);
auto on_value = inputs.at(2);
if (on_value == nullptr) {
return 1;
}
output->set_data_type(on_value->data_type());
output->SetFormat(on_value->GetFormat());
return 0;
}
} // namespace mindspore

View File

@ -1,76 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/pad.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
std::vector<int> Pad::GetPaddings() const { return this->primitive->value.AsPad()->paddings; }
int Pad::GetPaddingMode() const { return this->primitive->value.AsPad()->paddingMode; }
float Pad::GetConstantValue() const { return this->primitive->value.AsPad()->constantValue; }
void Pad::SetPaddings(const std::vector<int> &paddings) { this->primitive->value.AsPad()->paddings = paddings; }
void Pad::SetPaddingMode(int padding_mode) { this->primitive->value.AsPad()->paddingMode = padding_mode; }
void Pad::SetConstantValue(float constant_value) { this->primitive->value.AsPad()->constantValue = constant_value; }
#else
std::vector<int> Pad::GetPaddings() const {
auto fb_vector = this->primitive->value_as_Pad()->paddings();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
int Pad::GetPaddingMode() const { return this->primitive->value_as_Pad()->paddingMode(); }
float Pad::GetConstantValue() const { return this->primitive->value_as_Pad()->constantValue(); }
void Pad::SetPaddings(const std::vector<int> &paddings) {}
void Pad::SetPaddingMode(int padding_mode) {}
void Pad::SetConstantValue(float constant_value) {}
#endif
namespace {
const size_t kPaddingsSize = 8;
const size_t kInputRank = 4;
} // namespace
int Pad::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
MS_ASSERT(this->primitive != nullptr);
if (this->primitive == nullptr) {
return 1;
}
auto paddings = GetPaddings();
auto input = inputs.front();
if (input == nullptr) {
return 1;
}
auto input_shape = input->shape();
std::vector<int> output_shape;
MS_ASSERT(input->shape().size() <= kInputRank);
for (size_t i = 0; i < input_shape.size(); i++) {
auto paddings_index = i + kInputRank - input_shape.size();
auto shape = input_shape[i] + (paddings)[2 * paddings_index] + (paddings)[2 * paddings_index + 1];
output_shape.push_back(shape);
}
auto output = outputs.front();
if (output == nullptr) {
return 1;
}
output->SetFormat(input->GetFormat());
output->set_shape(output_shape);
output->set_data_type(input->data_type());
return 0;
}
} // namespace mindspore

View File

@ -1,139 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/pooling.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
int Pooling::GetFormat() const { return this->primitive->value.AsPooling()->format; }
int Pooling::GetPoolingMode() const { return this->primitive->value.AsPooling()->poolingMode; }
bool Pooling::GetGlobal() const { return this->primitive->value.AsPooling()->global; }
int Pooling::GetWindowW() const { return this->primitive->value.AsPooling()->windowW; }
int Pooling::GetWindowH() const { return this->primitive->value.AsPooling()->windowH; }
int Pooling::GetStrideW() const { return this->primitive->value.AsPooling()->strideW; }
int Pooling::GetStrideH() const { return this->primitive->value.AsPooling()->strideH; }
int Pooling::GetPadMode() const { return this->primitive->value.AsPooling()->padMode; }
int Pooling::GetPadUp() const { return this->primitive->value.AsPooling()->padUp; }
int Pooling::GetPadDown() const { return this->primitive->value.AsPooling()->padDown; }
int Pooling::GetPadLeft() const { return this->primitive->value.AsPooling()->padLeft; }
int Pooling::GetPadRight() const { return this->primitive->value.AsPooling()->padRight; }
int Pooling::GetRoundMode() const { return this->primitive->value.AsPooling()->roundMode; }
void Pooling::SetFormat(int format) { this->primitive->value.AsPooling()->format = (schema::Format)format; }
void Pooling::SetPoolingMode(int pooling_mode) {
this->primitive->value.AsPooling()->poolingMode = (schema::PoolMode)pooling_mode;
}
void Pooling::SetGlobal(bool global) { this->primitive->value.AsPooling()->global = global; }
void Pooling::SetWindowW(int window_w) { this->primitive->value.AsPooling()->windowW = window_w; }
void Pooling::SetWindowH(int window_h) { this->primitive->value.AsPooling()->windowH = window_h; }
void Pooling::SetStrideW(int stride_w) { this->primitive->value.AsPooling()->strideW = stride_w; }
void Pooling::SetStrideH(int stride_h) { this->primitive->value.AsPooling()->strideH = stride_h; }
void Pooling::SetPadMode(int pad_mode) { this->primitive->value.AsPooling()->padMode = (schema::PadMode)pad_mode; }
void Pooling::SetPadUp(int pad_up) { this->primitive->value.AsPooling()->padUp = pad_up; }
void Pooling::SetPadDown(int pad_down) { this->primitive->value.AsPooling()->padDown = pad_down; }
void Pooling::SetPadLeft(int pad_left) { this->primitive->value.AsPooling()->padLeft = pad_left; }
void Pooling::SetPadRight(int pad_right) { this->primitive->value.AsPooling()->padRight = pad_right; }
void Pooling::SetRoundMode(int round_mode) {
this->primitive->value.AsPooling()->roundMode = (schema::RoundMode)round_mode;
}
#else
int Pooling::GetFormat() const { return this->primitive->value_as_Pooling()->format(); }
int Pooling::GetPoolingMode() const { return this->primitive->value_as_Pooling()->poolingMode(); }
bool Pooling::GetGlobal() const { return this->primitive->value_as_Pooling()->global(); }
int Pooling::GetWindowW() const { return this->primitive->value_as_Pooling()->windowW(); }
int Pooling::GetWindowH() const { return this->primitive->value_as_Pooling()->windowH(); }
int Pooling::GetStrideW() const { return this->primitive->value_as_Pooling()->strideW(); }
int Pooling::GetStrideH() const { return this->primitive->value_as_Pooling()->strideH(); }
int Pooling::GetPadMode() const { return this->primitive->value_as_Pooling()->padMode(); }
int Pooling::GetPadUp() const { return this->primitive->value_as_Pooling()->padUp(); }
int Pooling::GetPadDown() const { return this->primitive->value_as_Pooling()->padDown(); }
int Pooling::GetPadLeft() const { return this->primitive->value_as_Pooling()->padLeft(); }
int Pooling::GetPadRight() const { return this->primitive->value_as_Pooling()->padRight(); }
int Pooling::GetRoundMode() const { return this->primitive->value_as_Pooling()->roundMode(); }
void Pooling::SetFormat(int format) {}
void Pooling::SetPoolingMode(int pooling_mode) {}
void Pooling::SetGlobal(bool global) {}
void Pooling::SetWindowW(int window_w) {}
void Pooling::SetWindowH(int window_h) {}
void Pooling::SetStrideW(int stride_w) {}
void Pooling::SetStrideH(int stride_h) {}
void Pooling::SetPadMode(int pad_mode) {}
void Pooling::SetPadUp(int pad_up) {}
void Pooling::SetPadDown(int pad_down) {}
void Pooling::SetPadLeft(int pad_left) {}
void Pooling::SetPadRight(int pad_right) {}
void Pooling::SetRoundMode(int round_mode) {}
#endif
int Pooling::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
int input_h = input->shape().at(1);
int input_w = input->shape().at(2);
MS_ASSERT(pooling_prim != nullptr);
auto window_h = GetWindowH();
auto window_w = GetWindowH();
if (GetGlobal()) {
window_h = input_h;
window_w = input_w;
}
int output_h = 0;
int output_w = 0;
pad_l_ = GetPadLeft();
pad_u_ = GetPadUp();
pad_d_ = GetPadDown();
pad_r_ = GetPadRight();
if ((schema::PadMode)GetPadMode() == schema::PadMode_SAME) {
output_w = std::ceil(static_cast<float>(input_w) / static_cast<float>(GetStrideW()));
output_h = std::ceil(static_cast<float>(input_h) / static_cast<float>(GetStrideH()));
auto pad_h_all = ((output_h - 1) * GetStrideH() + (window_h - 1) + 1 - input_h);
auto pad_w_all = ((output_w - 1) * GetStrideW() + (window_w - 1) + 1 - input_w);
pad_u_ = pad_h_all / 2;
pad_d_ = pad_h_all - pad_u_;
pad_l_ = pad_w_all / 2;
pad_r_ = pad_w_all - pad_l_;
} else {
auto round_mode = GetRoundMode();
if (round_mode == schema::RoundMode_FLOOR) {
output_h = std::floor(static_cast<float>(input_h + pad_u_ + pad_d_ - window_h) / GetStrideH()) + 1;
output_w = std::floor(static_cast<float>(input_w + pad_l_ + pad_r_ - window_w) / GetStrideW()) + 1;
} else if (round_mode == schema::RoundMode_CEIL) {
output_h = std::ceil(static_cast<float>(input_h + pad_u_ + pad_d_ - window_h) / GetStrideH()) + 1;
output_w = std::ceil(static_cast<float>(input_w + pad_l_ + pad_r_ - window_w) / GetStrideW()) + 1;
} else {
MS_LOG(ERROR) << "unsupported round mode.";
}
}
// todo: fmk type
auto input_shape = input->shape();
input_shape.at(1) = output_h;
input_shape.at(2) = output_w;
output->set_shape(input_shape);
output->set_data_type(input->data_type());
// todo: temp fix
output->SetFormat(schema::Format_NHWC);
return 0;
}
} // namespace mindspore

View File

@ -1,62 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/power.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
float Power::GetPower() const { return this->primitive->value.AsPower()->power; }
float Power::GetScale() const { return this->primitive->value.AsPower()->scale; }
float Power::GetShift() const { return this->primitive->value.AsPower()->shift; }
void Power::SetPower(float power) { this->primitive->value.AsPower()->power = power; }
void Power::SetScale(float scale) { this->primitive->value.AsPower()->scale = scale; }
void Power::SetShift(float shift) { this->primitive->value.AsPower()->shift = shift; }
#else
float Power::GetPower() const { return this->primitive->value_as_Power()->power(); }
float Power::GetScale() const { return this->primitive->value_as_Power()->scale(); }
float Power::GetShift() const { return this->primitive->value_as_Power()->shift(); }
void Power::SetPower(float power) {}
void Power::SetScale(float scale) {}
void Power::SetShift(float shift) {}
#endif
int Power::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
MS_ASSERT(this->primitive != nullptr);
auto x_tensor = inputs[0];
MS_ASSERT(x_tensor != nullptr);
lite::tensor::Tensor *exp_tensor = nullptr;
if (inputs.size() == 2) {
exp_tensor = inputs[1];
MS_ASSERT(exp_tensor != nullptr);
}
auto output_tensor = outputs[0];
MS_ASSERT(output_tensor != nullptr);
if (exp_tensor != nullptr) {
if (exp_tensor->shape() != x_tensor->shape() || exp_tensor->data_type() != x_tensor->data_type()) {
MS_LOG(ERROR) << "Power inputs shape or type is not equal!";
return 1;
}
}
output_tensor->SetFormat(x_tensor->GetFormat());
output_tensor->set_shape(x_tensor->shape());
output_tensor->set_data_type(x_tensor->data_type());
return 0;
}
} // namespace mindspore

View File

@ -1,127 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/prior_box.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
std::vector<int> PriorBox::GetMinSizes() const { return this->primitive->value.AsPriorBox()->max_sizes; }
std::vector<int> PriorBox::GetMaxSizes() const { return this->primitive->value.AsPriorBox()->max_sizes; }
std::vector<float> PriorBox::GetAspectRatios() const { return this->primitive->value.AsPriorBox()->aspect_ratios; }
std::vector<float> PriorBox::GetVariances() const { return this->primitive->value.AsPriorBox()->variances; }
int PriorBox::GetImageSizeW() const { return this->primitive->value.AsPriorBox()->image_size_w; }
int PriorBox::GetImageSizeH() const { return this->primitive->value.AsPriorBox()->image_size_h; }
float PriorBox::GetStepW() const { return this->primitive->value.AsPriorBox()->step_w; }
float PriorBox::GetStepH() const { return this->primitive->value.AsPriorBox()->step_h; }
bool PriorBox::GetClip() const { return this->primitive->value.AsPriorBox()->clip; }
bool PriorBox::GetFlip() const { return this->primitive->value.AsPriorBox()->flip; }
float PriorBox::GetOffset() const { return this->primitive->value.AsPriorBox()->offset; }
void PriorBox::SetMinSizes(const std::vector<int> &min_sizes) {
this->primitive->value.AsPriorBox()->min_sizes = min_sizes;
}
void PriorBox::SetMaxSizes(const std::vector<int> &max_sizes) {
this->primitive->value.AsPriorBox()->max_sizes = max_sizes;
}
void PriorBox::SetAspectRatios(const std::vector<float> &aspect_ratios) {
this->primitive->value.AsPriorBox()->aspect_ratios = aspect_ratios;
}
void PriorBox::SetVariances(const std::vector<float> &variances) {
this->primitive->value.AsPriorBox()->variances = variances;
}
void PriorBox::SetImageSizeW(int image_size_w) { this->primitive->value.AsPriorBox()->image_size_w = image_size_w; }
void PriorBox::SetImageSizeH(int image_size_h) { this->primitive->value.AsPriorBox()->image_size_h = image_size_h; }
void PriorBox::SetStepW(float step_w) { this->primitive->value.AsPriorBox()->step_w = step_w; }
void PriorBox::SetStepH(float step_h) { this->primitive->value.AsPriorBox()->step_h = step_h; }
void PriorBox::SetClip(bool clip) { this->primitive->value.AsPriorBox()->clip = clip; }
void PriorBox::SetFlip(bool flip) { this->primitive->value.AsPriorBox()->flip = flip; }
void PriorBox::SetOffset(float offset) { this->primitive->value.AsPriorBox()->offset = offset; }
#else
std::vector<int> PriorBox::GetMinSizes() const {
auto fb_vector = this->primitive->value_as_PriorBox()->min_sizes();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
std::vector<int> PriorBox::GetMaxSizes() const {
auto fb_vector = this->primitive->value_as_PriorBox()->max_sizes();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
std::vector<float> PriorBox::GetAspectRatios() const {
auto fb_vector = this->primitive->value_as_PriorBox()->aspect_ratios();
return std::vector<float>(fb_vector->begin(), fb_vector->end());
}
std::vector<float> PriorBox::GetVariances() const {
auto fb_vector = this->primitive->value_as_PriorBox()->variances();
return std::vector<float>(fb_vector->begin(), fb_vector->end());
}
int PriorBox::GetImageSizeW() const { return this->primitive->value_as_PriorBox()->image_size_w(); }
int PriorBox::GetImageSizeH() const { return this->primitive->value_as_PriorBox()->image_size_h(); }
float PriorBox::GetStepW() const { return this->primitive->value_as_PriorBox()->step_w(); }
float PriorBox::GetStepH() const { return this->primitive->value_as_PriorBox()->step_h(); }
bool PriorBox::GetClip() const { return this->primitive->value_as_PriorBox()->clip(); }
bool PriorBox::GetFlip() const { return this->primitive->value_as_PriorBox()->flip(); }
float PriorBox::GetOffset() const { return this->primitive->value_as_PriorBox()->offset(); }
void PriorBox::SetMinSizes(const std::vector<int> &min_sizes) {}
void PriorBox::SetMaxSizes(const std::vector<int> &max_sizes) {}
void PriorBox::SetAspectRatios(const std::vector<float> &aspect_ratios) {}
void PriorBox::SetVariances(const std::vector<float> &variances) {}
void PriorBox::SetImageSizeW(int image_size_w) {}
void PriorBox::SetImageSizeH(int image_size_h) {}
void PriorBox::SetStepW(float step_w) {}
void PriorBox::SetStepH(float step_h) {}
void PriorBox::SetClip(bool clip) {}
void PriorBox::SetFlip(bool flip) {}
void PriorBox::SetOffset(float offset) {}
#endif
namespace {
constexpr int kPriorBoxPoints = 4;
constexpr int kPriorBoxN = 1;
constexpr int kPriorBoxW = 1;
constexpr int kPriorBoxC = 2;
} // namespace
int PriorBox::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
std::vector<float> different_aspect_ratios{1.0f};
auto aspect_ratios = GetAspectRatios();
MS_ASSERT(aspect_ratios != nullptr);
for (auto i = 0; i < aspect_ratios.size(); i++) {
float ratio = (aspect_ratios)[i];
bool exist = std::any_of(different_aspect_ratios.begin(), different_aspect_ratios.end(),
[&](float v) { return abs(ratio - v) < 1e-6; });
if (!exist) {
different_aspect_ratios.emplace_back(ratio);
if (GetFlip()) {
different_aspect_ratios.emplace_back(1.0f / ratio);
}
}
}
int32_t num_priors_box = GetMinSizes().size() * different_aspect_ratios.size() + GetMaxSizes().size();
auto input = inputs_.at(0);
MS_ASSERT(input != nullptr);
int32_t h = input->Height() * input->Width() * num_priors_box * kPriorBoxPoints;
std::vector<int> output_shape{kPriorBoxN, h, kPriorBoxW, kPriorBoxC};
auto output = outputs_.at(0);
MS_ASSERT(output != nullptr);
output->set_shape(output_shape);
output->set_data_type(kNumberTypeFloat32);
output->SetFormat(input->GetFormat());
return 0;
}
} // namespace mindspore

View File

@ -1,49 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/quant_dtype_cast.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
int QuantDTypeCast::GetSrcT() const { return this->primitive->value.AsQuantDTypeCast()->srcT; }
int QuantDTypeCast::GetDstT() const { return this->primitive->value.AsQuantDTypeCast()->dstT; }
void QuantDTypeCast::SetSrcT(int src_t) { this->primitive->value.AsQuantDTypeCast()->srcT = src_t; }
void QuantDTypeCast::SetDstT(int dst_t) { this->primitive->value.AsQuantDTypeCast()->dstT = dst_t; }
#else
int QuantDTypeCast::GetSrcT() const { return this->primitive->value_as_QuantDTypeCast()->srcT(); }
int QuantDTypeCast::GetDstT() const { return this->primitive->value_as_QuantDTypeCast()->dstT(); }
void QuantDTypeCast::SetSrcT(int src_t) {}
void QuantDTypeCast::SetDstT(int dst_t) {}
#endif
int QuantDTypeCast::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
auto param = primitive->value_as_QuantDTypeCast();
MS_ASSERT(input->data_type() == param->srcT);
output->set_data_type(static_cast<TypeId>(param->dstT()));
output->SetFormat(input->GetFormat());
return 0;
}
} // namespace mindspore

View File

@ -1,59 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/range.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
int Range::GetDType() const { return this->primitive->value.AsRange()->dType; }
int Range::GetStart() const { return this->primitive->value.AsRange()->start; }
int Range::GetLimit() const { return this->primitive->value.AsRange()->limit; }
int Range::GetDelta() const { return this->primitive->value.AsRange()->delta; }
void Range::SetDType(int d_type) { this->primitive->value.AsRange()->dType = d_type; }
void Range::SetStart(int start) { this->primitive->value.AsRange()->start = start; }
void Range::SetLimit(int limit) { this->primitive->value.AsRange()->limit = limit; }
void Range::SetDelta(int delta) { this->primitive->value.AsRange()->delta = delta; }
#else
int Range::GetDType() const { return this->primitive->value_as_Range()->dType(); }
int Range::GetStart() const { return this->primitive->value_as_Range()->start(); }
int Range::GetLimit() const { return this->primitive->value_as_Range()->limit(); }
int Range::GetDelta() const { return this->primitive->value_as_Range()->delta(); }
void Range::SetDType(int d_type) {}
void Range::SetStart(int start) {}
void Range::SetLimit(int limit) {}
void Range::SetDelta(int delta) {}
#endif
int Range::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
int shape_size = std::ceil(static_cast<float>(GetLimit() - GetStart()) / GetDelta());
std::vector<int> in_shape(1);
in_shape.push_back(shape_size);
output->set_shape(in_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return 0;
}
} // namespace mindspore

View File

@ -1,33 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/rank.h"
namespace mindspore {
int Rank::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
std::vector<int> in_shape(1, 1);
output->set_shape(in_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return 0;
}
} // namespace mindspore

View File

@ -1,99 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/reduce.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
std::vector<int> Reduce::GetAxes() const { return this->primitive->value.AsReduce()->axes; }
int Reduce::GetKeepDims() const { return this->primitive->value.AsReduce()->keepDims; }
int Reduce::GetMode() const { return this->primitive->value.AsReduce()->mode; }
void Reduce::SetAxes(const std::vector<int> &axes) { this->primitive->value.AsReduce()->axes = axes; }
void Reduce::SetKeepDims(int keep_dims) { this->primitive->value.AsReduce()->keepDims = keep_dims; }
void Reduce::SetMode(int mode) { this->primitive->value.AsReduce()->mode = (schema::ReduceMode)mode; }
#else
std::vector<int> Reduce::GetAxes() const {
auto fb_vector = this->primitive->value_as_Reduce()->axes();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
int Reduce::GetKeepDims() const { return this->primitive->value_as_Reduce()->keepDims(); }
int Reduce::GetMode() const { return this->primitive->value_as_Reduce()->mode(); }
void Reduce::SetAxes(const std::vector<int> &axes) {}
void Reduce::SetKeepDims(int keep_dims) {}
void Reduce::SetMode(int mode) {}
#endif
namespace {
constexpr size_t kInputSize = 1;
constexpr size_t kOutputSize = 1;
} // namespace
int Reduce::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
if (inputs_.size() != kInputSize || outputs_.size() != kOutputSize) {
return 1;
}
auto input = inputs_.front();
auto output = outputs_.front();
if (input == nullptr || output == nullptr) {
return 1;
}
if (this->primitive == nullptr) {
return 1;
}
bool keep_dims = static_cast<bool>(GetKeepDims());
std::vector<int> in_shape = input->shape();
std::vector<int> out_shape;
const auto &axes = GetAxes();
auto num_axes = axes.size();
// reduce on all axes
if (num_axes == 0) {
if (keep_dims) {
for (auto i = 0; i < in_shape.size(); i++) {
out_shape.push_back(1);
}
}
output->set_shape(out_shape);
output->set_data_type(input->data_type());
return 0;
}
// reduce on selected axes
for (size_t i = 0; i < in_shape.size(); i++) {
bool reduce_axis = false;
for (int idx = 0; idx < num_axes; ++idx) {
if (static_cast<size_t>((axes)[idx]) == i) {
reduce_axis = true;
break;
}
}
if (reduce_axis) {
if (keep_dims) {
out_shape.push_back(1);
}
} else {
out_shape.push_back(in_shape[i]);
}
}
output->set_shape(out_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return 0;
}
} // namespace mindspore

View File

@ -1,153 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/reshape.h"
#include <algorithm>
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
int Reshape::GetFormat() const { return this->primitive->value.AsReshape()->format; }
std::vector<long> Reshape::GetShape() const { return this->primitive->value.AsReshape()->shape; }
void Reshape::SetFormat(int format) { this->primitive->value.AsReshape()->format = format; }
void Reshape::SetShape(const std::vector<long> &shape) { this->primitive->value.AsReshape()->shape = shape; }
#else
int Reshape::GetFormat() const { return this->primitive->value_as_Reshape()->format(); }
std::vector<long> Reshape::GetShape() const {
auto fb_vector = this->primitive->value_as_Reshape()->shape();
return std::vector<long>(fb_vector->begin(), fb_vector->end());
}
void Reshape::SetFormat(int format) {}
void Reshape::SetShape(const std::vector<long> &shape) {}
#endif
int Reshape::CalNewShape(const lite::tensor::Tensor *in_tensor, std::vector<int> *out_shape) const {
size_t in_shape_size = 1;
for (size_t i = 0; i < in_tensor->shape().size(); i++) {
in_shape_size *= in_tensor->shape()[i];
}
int64_t inferIndex = -1;
size_t out_shapeSize = 1;
for (size_t i = 0; i < out_shape->size(); i++) {
if (out_shape->at(i) == -1) {
if (inferIndex == -1) {
inferIndex = i;
} else {
MS_LOG(ERROR) << "output shape should has no more than one dim which need infer";
return 1;
}
} else if (out_shape->at(i) < 0) {
MS_LOG(ERROR) << "output shape dim should be non-negative";
return 1;
} else if (out_shape->at(i) == 0) {
out_shape->at(i) = in_tensor->shape().at(i);
out_shapeSize *= out_shape->at(i);
} else {
out_shapeSize *= out_shape->at(i);
}
}
if (inferIndex == -1 && out_shapeSize != in_shape_size) {
MS_LOG(ERROR) << "output shapeSize: " << out_shapeSize << " should be equal to input shapeSize: " << in_shape_size;
return 1;
}
if (inferIndex != -1) {
out_shape->at(inferIndex) = in_shape_size / out_shapeSize;
}
return 0;
}
template <typename T>
void CalShape(const T *data, const std::vector<lite::tensor::Tensor *> &inputs, std::vector<int> *out_shape,
int shape_size) {
int input_count = inputs[0]->ElementsNum();
int index = 0;
int size = 1;
for (size_t i = 0; i < shape_size; i++) {
if (data[i] == -1) {
index = i;
} else {
size *= data[i];
}
out_shape->push_back(data[i]);
}
if (data[index] == -1) {
(*out_shape)[index] = input_count / size;
}
}
int Reshape::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
std::vector<int> out_shape;
if (inputs_.size() == kDoubleNum) {
auto shape_tensor = inputs_.at(1);
if (shape_tensor->Data() == nullptr) {
MS_LOG(INFO) << "Do infer shape in runtime.";
return 1;
}
size_t shape_size = shape_tensor->ElementsNum();
switch (shape_tensor->data_type()) {
case kNumberTypeInt8: {
auto data = reinterpret_cast<int8_t *>(shape_tensor->Data());
CalShape<int8_t>(data, inputs_, &out_shape, shape_size);
} break;
case kNumberTypeInt32: {
auto data = reinterpret_cast<int32_t *>(shape_tensor->Data());
CalShape<int32_t>(data, inputs_, &out_shape, shape_size);
} break;
case kNumberTypeFloat: {
auto data = reinterpret_cast<float *>(shape_tensor->Data());
CalShape<float>(data, inputs_, &out_shape, shape_size);
} break;
case kNumberTypeUInt32: {
auto data = reinterpret_cast<uint32_t *>(shape_tensor->Data());
CalShape<uint32_t>(data, inputs_, &out_shape, shape_size);
} break;
default: {
MS_LOG(ERROR) << "Reshape weight tensor has unsupported dataType: " << shape_tensor->data_type();
return 1;
}
}
} else if (inputs_.size() == kSingleNum) {
std::copy(GetShape().begin(), GetShape().end(), std::back_inserter(out_shape));
} else {
MS_LOG(ERROR) << "inputs tensor size invalid.";
return 1;
}
auto ret = CalNewShape(inputs_.front(), &out_shape);
if (ret != 0) {
MS_LOG(ERROR) << "CalNewShape error";
return ret;
}
output->set_shape(out_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return 0;
}
} // namespace mindspore

View File

@ -1,82 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/resize.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
int Resize::GetFormat() const { return this->primitive->value.AsResize()->format; }
int Resize::GetMethod() const { return this->primitive->value.AsResize()->method; }
long Resize::GetNewHeight() const { return this->primitive->value.AsResize()->newHeight; }
long Resize::GetNewWidth() const { return this->primitive->value.AsResize()->newWidth; }
bool Resize::GetAlignCorners() const { return this->primitive->value.AsResize()->alignCorners; }
bool Resize::GetPreserveAspectRatio() const { return this->primitive->value.AsResize()->preserveAspectRatio; }
void Resize::SetFormat(int format) { this->primitive->value.AsResize()->format = (schema::Format)format; }
void Resize::SetMethod(int method) { this->primitive->value.AsResize()->method = (schema::ResizeMethod)method; }
void Resize::SetNewHeight(long new_height) { this->primitive->value.AsResize()->newHeight = new_height; }
void Resize::SetNewWidth(long new_width) { this->primitive->value.AsResize()->newWidth = new_width; }
void Resize::SetAlignCorners(bool align_corners) { this->primitive->value.AsResize()->alignCorners = align_corners; }
void Resize::SetPreserveAspectRatio(bool preserve_aspect_ratio) {
this->primitive->value.AsResize()->preserveAspectRatio = preserve_aspect_ratio;
}
#else
int Resize::GetFormat() const { return this->primitive->value_as_Resize()->format(); }
int Resize::GetMethod() const { return this->primitive->value_as_Resize()->method(); }
long Resize::GetNewHeight() const { return this->primitive->value_as_Resize()->newHeight(); }
long Resize::GetNewWidth() const { return this->primitive->value_as_Resize()->newWidth(); }
bool Resize::GetAlignCorners() const { return this->primitive->value_as_Resize()->alignCorners(); }
bool Resize::GetPreserveAspectRatio() const { return this->primitive->value_as_Resize()->preserveAspectRatio(); }
void Resize::SetFormat(int format) {}
void Resize::SetMethod(int method) {}
void Resize::SetNewHeight(long new_height) {}
void Resize::SetNewWidth(long new_width) {}
void Resize::SetAlignCorners(bool align_corners) {}
void Resize::SetPreserveAspectRatio(bool preserve_aspect_ratio) {}
#endif
namespace {
constexpr int kInputRank = 4;
} // namespace
int Resize::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front();
if (input == nullptr) {
return 1;
}
MS_ASSERT(input->shape().size() == kInputRank);
auto output = outputs_.front();
if (output == nullptr) {
return 1;
}
auto new_height = GetNewHeight();
auto new_width = GetNewWidth();
std::vector<int> output_shape;
output_shape.push_back(input->Batch());
output_shape.push_back(new_height);
output_shape.push_back(new_width);
output_shape.push_back(input->Channel());
output->set_shape(output_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return 0;
}
} // namespace mindspore

View File

@ -1,60 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/reverse_sequence.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
int ReverseSequence::GetSeqAxis() const { return this->primitive->value.AsReverseSequence()->seqAxis; }
int ReverseSequence::GetBatchAxis() const { return this->primitive->value.AsReverseSequence()->batchAxis; }
std::vector<int> ReverseSequence::GetSeqLengths() const {
return this->primitive->value.AsReverseSequence()->seqLengths;
}
void ReverseSequence::SetSeqAxis(int seq_axis) { this->primitive->value.AsReverseSequence()->seqAxis = seq_axis; }
void ReverseSequence::SetBatchAxis(int batch_axis) {
this->primitive->value.AsReverseSequence()->batchAxis = batch_axis;
}
void ReverseSequence::SetSeqLengths(const std::vector<int> &seq_lengths) {
this->primitive->value.AsReverseSequence()->seqLengths = seq_lengths;
}
#else
int ReverseSequence::GetSeqAxis() const { return this->primitive->value_as_ReverseSequence()->seqAxis(); }
int ReverseSequence::GetBatchAxis() const { return this->primitive->value_as_ReverseSequence()->batchAxis(); }
std::vector<int> ReverseSequence::GetSeqLengths() const {
auto fb_vector = this->primitive->value_as_ReverseSequence()->seqLengths();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
void ReverseSequence::SetSeqAxis(int seq_axis) {}
void ReverseSequence::SetBatchAxis(int batch_axis) {}
void ReverseSequence::SetSeqLengths(const std::vector<int> &seq_lengths) {}
#endif
int ReverseSequence::InferShape(std::vector<lite::tensor::Tensor *> inputs,
std::vector<lite::tensor::Tensor *> outputs) {
auto input = inputs.front();
auto output = outputs.front();
MS_ASSERT(input != nullptr);
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return 0;
}
} // namespace mindspore

View File

@ -1,75 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/roi_pooling.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
int ROIPooling::GetPooledH() const { return this->primitive->value.AsROIPooling()->pooledH; }
int ROIPooling::GetPooledW() const { return this->primitive->value.AsROIPooling()->pooledW; }
float ROIPooling::GetScale() const { return this->primitive->value.AsROIPooling()->scale; }
void ROIPooling::SetPooledH(int pooled_h) { this->primitive->value.AsROIPooling()->pooledH = pooled_h; }
void ROIPooling::SetPooledW(int pooled_w) { this->primitive->value.AsROIPooling()->pooledW = pooled_w; }
void ROIPooling::SetScale(float scale) { this->primitive->value.AsROIPooling()->scale = scale; }
#else
int ROIPooling::GetPooledH() const { return this->primitive->value_as_ROIPooling()->pooledH(); }
int ROIPooling::GetPooledW() const { return this->primitive->value_as_ROIPooling()->pooledW(); }
float ROIPooling::GetScale() const { return this->primitive->value_as_ROIPooling()->scale(); }
void ROIPooling::SetPooledH(int pooled_h) {}
void ROIPooling::SetPooledW(int pooled_w) {}
void ROIPooling::SetScale(float scale) {}
#endif
int ROIPooling::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
if (inputs_.size() != kDoubleNum) {
MS_LOG(ERROR) << "inputs number is not equal to " << kDoubleNum;
return 1;
}
auto input = inputs_.front();
if (input == nullptr) {
return 1;
}
auto roi = inputs_.at(1);
if (roi == nullptr) {
return 1;
}
auto output = outputs_.front();
if (output == nullptr) {
return 1;
}
auto new_h = GetPooledH();
auto new_w = GetPooledW();
auto shape_data = roi->shape();
std::vector<int> output_shape;
output_shape.push_back(shape_data[0]);
output_shape.push_back(new_h);
output_shape.push_back(new_w);
output_shape.push_back(input->Channel());
output->set_shape(output_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return 0;
}
} // namespace mindspore

View File

@ -1,61 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/scatter_nd.h"
namespace mindspore {
namespace {
constexpr int kScatterNDInputNum = 3;
constexpr int kScatterNDOutputNum = 1;
constexpr int kScatterShapeIndex = 0;
constexpr int kScatterIndicesIndex = 1;
constexpr int kScatterUpdateIndex = 2;
} // namespace
int ScatterND::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
if (inputs_.size() != kScatterNDInputNum) {
MS_LOG(ERROR) << "inputs number is not equal to " << kScatterNDInputNum;
return 1;
}
if (outputs_.size() != kScatterNDOutputNum) {
MS_LOG(ERROR) << "outputs number is not equal to " << kScatterNDInputNum;
return 1;
}
auto shape = inputs_.at(kScatterShapeIndex);
if (shape == nullptr) {
MS_LOG(ERROR) << "shape null pointer dereferencing.";
return 1;
}
auto indices = inputs_.at(kScatterIndicesIndex);
if (indices == nullptr) {
MS_LOG(ERROR) << "indices null pointer dereferencing.";
return 1;
}
auto update = inputs_.at(kScatterUpdateIndex);
if (update == nullptr) {
MS_LOG(ERROR) << "update null pointer dereferencing.";
return 1;
}
auto output = outputs_.front();
auto shape_data = reinterpret_cast<int *>(shape->Data());
std::vector<int> out_shape(shape_data, shape_data + shape->DataSize());
output->set_shape(out_shape);
output->set_data_type(update->data_type());
output->SetFormat(update->GetFormat());
return 0;
}
} // namespace mindspore

View File

@ -1,52 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/shape.h"
namespace mindspore {
namespace {
constexpr int kShapeInputNum = 1;
constexpr int kShapeOutputNum = 1;
} // namespace
int Shape::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
if (inputs_.size() != kShapeInputNum) {
MS_LOG(ERROR) << "inputs to Shape operator should be 1, but " << inputs_.size() << " is given.";
return 1;
}
if (outputs_.size() != kShapeOutputNum) {
MS_LOG(ERROR) << "outputs to Shape operator should be 1, but " << outputs_.size() << " is given.";
return 1;
}
auto in_tensor = inputs_.front();
auto out_tensor = outputs_.front();
std::vector<int> out_shape;
out_shape.push_back(static_cast<int>(in_tensor->shape().size()));
auto ret_shape = out_tensor->set_shape(out_shape);
if (ret_shape != 1 || size_t(out_tensor->shape()[0]) != in_tensor->shape().size()) {
MS_LOG(ERROR) << "Set shape fails.";
return 1;
}
auto ret_dtype = out_tensor->set_data_type(in_tensor->data_type());
if (ret_dtype != in_tensor->data_type()) {
MS_LOG(ERROR) << "Set datatype fails.";
return 1;
}
return 0;
}
} // namespace mindspore

View File

@ -1,90 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/slice.h"
namespace mindspore {
namespace {
constexpr int kSliceInputNum = 1;
constexpr int kSliceOutputNum = 1;
} // namespace
#ifdef PRIMITIVE_WRITEABLE
int SliceOp::GetFormat() const { return this->primitive->value.AsSlice()->format; }
std::vector<int> SliceOp::GetBegin() const { return this->primitive->value.AsSlice()->begin; }
std::vector<int> SliceOp::GetSize() const { return this->primitive->value.AsSlice()->size; }
void SliceOp::SetFormat(int format) { this->primitive->value.AsSlice()->format = format; }
void SliceOp::SetBegin(const std::vector<int> &begin) { this->primitive->value.AsSlice()->begin = begin; }
void SliceOp::SetSize(const std::vector<int> &size) { this->primitive->value.AsSlice()->size = size; }
#else
int SliceOp::GetFormat() const { return this->primitive->value_as_Slice()->format(); }
std::vector<int> SliceOp::GetBegin() const {
auto fb_vector = this->primitive->value_as_Slice()->begin();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
std::vector<int> SliceOp::GetSize() const {
auto fb_vector = this->primitive->value_as_Slice()->size();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
void SliceOp::SetFormat(int format) {}
void SliceOp::SetBegin(const std::vector<int> &begin) {}
void SliceOp::SetSize(const std::vector<int> &size) {}
#endif
int SliceOp::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
MS_ASSERT(this->primitive != nullptr);
if (inputs.size() != kSliceInputNum || outputs.size() != kSliceOutputNum) {
MS_LOG(ERROR) << "input size:" << inputs.size() << ",output size:" << outputs.size();
return 1;
}
auto input = inputs.at(0);
auto input_shape = input->shape();
std::vector<int32_t> slice_begin(GetBegin().begin(), GetBegin().end());
std::vector<int32_t> slice_size(GetSize().begin(), GetSize().end());
std::vector<int32_t> output_shape(input_shape.size());
for (int i = 0; i < input_shape.size(); ++i) {
if (slice_size[i] < 0 && slice_size[i] != -1) {
MS_LOG(ERROR) << "Invalid size input!size[" << i << "]=" << slice_size[i];
return 1;
}
if (slice_begin[i] < 0) {
MS_LOG(ERROR) << "Invalid begin input " << slice_begin[i] << " which should be >= 0";
return 1;
}
if (input_shape[i] <= slice_begin[i]) {
MS_LOG(ERROR) << "Invalid begin input!begin[" << i << "]=" << slice_begin[i]
<< " which should be <= " << input_shape[i];
return 1;
}
if (slice_size[i] > (input_shape[i] - slice_begin[i])) {
MS_LOG(ERROR) << "Invalid size input " << slice_size[i]
<< " which should be <= " << input_shape[i] - slice_begin[i];
return 1;
}
output_shape[i] = slice_size[i] < 0 ? input_shape[i] - slice_begin[i] : slice_size[i];
}
outputs[0]->set_shape(output_shape);
outputs[0]->set_data_type(input->data_type());
outputs[0]->SetFormat(input->GetFormat());
return 0;
}
} // namespace mindspore

View File

@ -1,43 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/softmax.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
int SoftMax::GetAxis() const { return this->primitive->value.AsSoftMax()->axis; }
void SoftMax::SetAxis(int axis) { this->primitive->value.AsSoftMax()->axis = axis; }
#else
int SoftMax::GetAxis() const { return this->primitive->value_as_SoftMax()->axis(); }
void SoftMax::SetAxis(int axis) {}
#endif
int SoftMax::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return 0;
}
} // namespace mindspore

View File

@ -1,110 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/space_to_batch.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
std::vector<int> SpaceToBatch::GetBlockShape() const { return this->primitive->value.AsSpaceToBatch()->blockShape; }
std::vector<int> SpaceToBatch::GetPaddings() const { return this->primitive->value.AsSpaceToBatch()->paddings; }
void SpaceToBatch::SetBlockShape(const std::vector<int> &block_shape) {
this->primitive->value.AsSpaceToBatch()->blockShape = block_shape;
}
void SpaceToBatch::SetPaddings(const std::vector<int> &paddings) {
this->primitive->value.AsSpaceToBatch()->paddings = paddings;
}
#else
std::vector<int> SpaceToBatch::GetBlockShape() const {
auto fb_vector = this->primitive->value_as_SpaceToBatch()->blockShape();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
std::vector<int> SpaceToBatch::GetPaddings() const {
auto fb_vector = this->primitive->value_as_SpaceToBatch()->paddings();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
void SpaceToBatch::SetBlockShape(const std::vector<int> &block_shape) {}
void SpaceToBatch::SetPaddings(const std::vector<int> &paddings) {}
#endif
namespace {
constexpr int kSpaceToBatchNDOutputNum = 1;
constexpr int kSpaceToBatchNDInputNum = 1;
constexpr int kBlockSizesSize = 2;
constexpr int kPaddingsSize = 4;
} // namespace
int SpaceToBatch::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
MS_ASSERT(this->primitive != nullptr);
if (outputs.size() != kSpaceToBatchNDOutputNum || inputs.size() != kSpaceToBatchNDInputNum) {
MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
return 1;
}
auto input = inputs.at(0);
if (input->GetFormat() != schema::Format_NHWC) {
MS_LOG(ERROR) << "space_to_batch only support NHWC now!";
return 1;
}
auto input_shape = input->shape();
if (input_shape.size() != kDimension_4d) {
MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d;
return 1;
}
if (GetBlockShape().size() != kBlockSizesSize) {
MS_LOG(ERROR) << "Block shape size should be " << kBlockSizesSize;
return 1;
}
if (GetPaddings().size() != kPaddingsSize) {
MS_LOG(ERROR) << "Crops size should be " << kPaddingsSize;
return 1;
}
for (int &iter : GetBlockShape()) {
block_sizes_.emplace_back(iter);
}
in_shape_.clear();
padded_in_shape_.clear();
paddings_.clear();
in_shape_.emplace_back(input_shape.at(NHWC_N));
padded_in_shape_.emplace_back(input_shape.at(NHWC_N));
for (int i = 0; i < kBlockSizesSize; i++) {
in_shape_.emplace_back(input_shape.at(i + 1));
padded_in_shape_.emplace_back(input_shape.at(i + 1) + (paddings_.at(2 * i) + paddings_.at(2 * i + 1)));
paddings_.emplace_back(paddings_.at(2 * i));
paddings_.emplace_back(paddings_.at(2 * i + 1));
if (paddings_.back() % block_sizes_.at(i)) {
MS_LOG(ERROR) << "Padded shape does not divide block size " << block_sizes_.at(i);
return 1;
}
}
in_shape_.emplace_back(input_shape.at(NHWC_C));
padded_in_shape_.emplace_back(input_shape.at(NHWC_C));
std::vector<int32_t> output_shape(input_shape.size());
output_shape[NHWC_N] = input_shape[NHWC_N] * (block_sizes_[NHWC_N] * block_sizes_[NHWC_H]);
output_shape[NHWC_H] = input_shape[NHWC_H] / block_sizes_[NHWC_N];
output_shape[NHWC_W] = input_shape[NHWC_W] / block_sizes_[NHWC_H];
output_shape[NHWC_C] = input_shape[NHWC_C];
outputs[0]->set_shape(output_shape);
outputs[0]->set_data_type(input->data_type());
return 0;
}
} // namespace mindspore

View File

@ -1,73 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/space_to_depth.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
int SpaceToDepth::GetBlockSize() const { return this->primitive->value.AsSpaceToDepth()->blockSize; }
int SpaceToDepth::GetFormat() const { return this->primitive->value.AsSpaceToDepth()->format; }
void SpaceToDepth::SetBlockSize(int block_size) { this->primitive->value.AsSpaceToDepth()->blockSize = block_size; }
void SpaceToDepth::SetFormat(int format) { this->primitive->value.AsSpaceToDepth()->format = format; }
#else
int SpaceToDepth::GetBlockSize() const { return this->primitive->value_as_SpaceToDepth()->blockSize(); }
int SpaceToDepth::GetFormat() const { return this->primitive->value_as_SpaceToDepth()->format(); }
void SpaceToDepth::SetBlockSize(int block_size) {}
void SpaceToDepth::SetFormat(int format) {}
#endif
namespace {
constexpr int kSpaceToDepthOutputNum = 1;
constexpr int kSpaceToDepthInputNum = 1;
} // namespace
int SpaceToDepth::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
MS_ASSERT(this->primitive != nullptr);
if (outputs.size() != kSpaceToDepthOutputNum || inputs.size() != kSpaceToDepthInputNum) {
MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
return 1;
}
auto input = inputs.at(0);
if (input->GetFormat() != schema::Format_NHWC) {
MS_LOG(ERROR) << "space_to_depth only support NHWC now!";
return 1;
}
auto input_shape = input->shape();
if (input_shape.size() != kDimension_4d) {
MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d;
return 1;
}
int32_t block_size = GetBlockSize();
if (input_shape[NHWC_C] % (block_size * block_size) != 0 || input_shape[NHWC_C] == 0) {
MS_LOG(ERROR) << "input dimension c size " << input_shape[NHWC_C] << " should be mulitple of block_size("
<< block_size << ") * block_size)!";
return 1;
}
std::vector<int32_t> output_shape(input_shape.size());
output_shape[NHWC_N] = input_shape[NHWC_N];
output_shape[NHWC_H] = input_shape[NHWC_H] / block_size;
output_shape[NHWC_W] = input_shape[NHWC_W] / block_size;
output_shape[NHWC_C] = input_shape[NHWC_C] * (block_size * block_size);
outputs[0]->set_shape(output_shape);
outputs[0]->set_data_type(input->data_type());
return 0;
}
} // namespace mindspore

View File

@ -1,83 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/split.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
int Split::GetNumberSplit() const { return this->primitive->value.AsSplit()->numberSplit; }
std::vector<int> Split::GetSizeSplits() const { return this->primitive->value.AsSplit()->sizeSplits; }
int Split::GetSplitDim() const { return this->primitive->value.AsSplit()->splitDim; }
void Split::SetNumberSplit(int number_split) { this->primitive->value.AsSplit()->numberSplit = number_split; }
void Split::SetSizeSplits(const std::vector<int> &size_splits) {
this->primitive->value.AsSplit()->sizeSplits = size_splits;
}
void Split::SetSplitDim(int split_dim) { this->primitive->value.AsSplit()->splitDim = split_dim; }
#else
int Split::GetNumberSplit() const { return this->primitive->value_as_Split()->numberSplit(); }
std::vector<int> Split::GetSizeSplits() const {
auto fb_vector = this->primitive->value_as_Split()->sizeSplits();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
int Split::GetSplitDim() const { return this->primitive->value_as_Split()->splitDim(); }
void Split::SetNumberSplit(int number_split) {}
void Split::SetSizeSplits(const std::vector<int> &size_splits) {}
void Split::SetSplitDim(int split_dim) {}
#endif
namespace {
constexpr int kSplitInputNum = 1;
} // namespace
int Split::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
MS_ASSERT(spilt_prim != nullptr);
if (inputs_.size() != kSplitInputNum) {
MS_LOG(ERROR) << "inputs number is not equal to " << kSplitInputNum;
return 1;
}
auto output = outputs_.front();
if (output == nullptr) {
MS_LOG(ERROR) << "output null pointer dereferencing.";
return 1;
}
int number_split = GetNumberSplit();
if (outputs_.size() != number_split) {
MS_LOG(ERROR) << "outputs number is not equal to " << number_split;
return 1;
}
int split_dim = GetSplitDim();
std::vector<int> input_shape = input->shape();
std::vector<int> size_split;
size_split.insert(size_split.begin(), GetSizeSplits().begin(), GetSizeSplits().end());
for (int i = 0; i < number_split; ++i) {
std::vector<int> output_shape;
output_shape.insert(output_shape.begin(), input_shape.begin(), input_shape.end());
auto split_dim_i = size_split.empty() ? input_shape[split_dim] / number_split : size_split[i];
output_shape[split_dim] = split_dim_i;
outputs_[i]->set_shape(output_shape);
outputs_[i]->set_data_type(input->data_type());
outputs_[i]->SetFormat(input->GetFormat());
}
return 0;
}
} // namespace mindspore

View File

@ -1,84 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/squeeze.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
std::vector<int> Squeeze::GetAxis() const { return this->primitive->value.AsSqueeze()->axis; }
void Squeeze::SetAxis(const std::vector<int> &axis) { this->primitive->value.AsSqueeze()->axis = axis; }
#else
std::vector<int> Squeeze::GetAxis() const {
auto fb_vector = this->primitive->value_as_Squeeze()->axis();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
void Squeeze::SetAxis(const std::vector<int> &axis) {}
#endif
namespace {
constexpr int kSqueezeInputNum = 1;
constexpr int kSqueezeOutputNum = 1;
} // namespace
int Squeeze::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
if (kSqueezeInputNum != inputs_.size()) {
MS_LOG(ERROR) << "Add should has " << kSqueezeInputNum << " inputs";
return -1;
}
if (kSqueezeOutputNum != outputs_.size()) {
MS_LOG(ERROR) << "Add should has " << kSqueezeOutputNum << " outputs";
return -1;
}
auto *in_tensor = inputs_.front();
auto in_shape = in_tensor->shape();
std::vector<int> out_shape;
// todo: getAxis
auto axis = GetAxis();
std::vector<int> axes_;
for (auto iter = axis.begin(); iter != axis.end(); iter++) {
axes_.push_back(*iter);
}
if (axes_.size() == 0) {
for (int i = 0; i < in_shape.size(); i++) {
if (in_shape[i] != 1) {
out_shape.push_back(in_shape[i]);
}
}
} else {
int axisIdx = 0;
for (int i = 0; i < in_shape.size(); i++) {
if (axisIdx < axes_.size() && axes_[axisIdx] == i) {
MS_ASSERT(in_shape[i] == 1);
axisIdx++;
continue;
} else {
out_shape.push_back(in_shape[i]);
}
}
}
outputs_.front()->set_shape(out_shape);
outputs_.front()->set_data_type(in_tensor->data_type());
outputs_.front()->SetFormat(in_tensor->GetFormat());
return 0;
}
} // namespace mindspore

View File

@ -1,93 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/stack.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
int Stack::GetAxis() const { return this->primitive->value.AsStack()->axis; }
int Stack::GetN() const { return this->primitive->value.AsStack()->n; }
std::vector<int> Stack::GetIsScale() const { return this->primitive->value.AsStack()->isScale; }
void Stack::SetAxis(int axis) { this->primitive->value.AsStack()->axis = axis; }
void Stack::SetN(int n) { this->primitive->value.AsStack()->n = n; }
void Stack::SetIsScale(const std::vector<int> &is_scale) { this->primitive->value.AsStack()->isScale = is_scale; }
#else
int Stack::GetAxis() const { return this->primitive->value_as_Stack()->axis(); }
int Stack::GetN() const { return this->primitive->value_as_Stack()->n(); }
std::vector<int> Stack::GetIsScale() const {
auto fb_vector = this->primitive->value_as_Stack()->isScale();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
void Stack::SetAxis(int axis) {}
void Stack::SetN(int n) {}
void Stack::SetIsScale(const std::vector<int> &is_scale) {}
#endif
namespace {
constexpr int kStackOutputNum = 1;
constexpr int kStackMinInputNum = 2;
} // namespace
int Stack::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
MS_ASSERT(this->primitive != nullptr);
if (outputs.size() != kStackOutputNum) {
MS_LOG(ERROR) << "Invalid output size:" << outputs.size();
return 1;
}
if (inputs.size() < kStackMinInputNum) {
MS_LOG(ERROR) << "Invalid input size " << inputs.size();
return 1;
}
auto input = inputs.at(0);
auto input_shape = input->shape();
std::vector<int32_t> output_shape = input_shape;
int axis = GetAxis() < 0 ? GetAxis() + input_shape.size() : GetAxis();
if (axis < 0 || axis > input_shape.size()) {
MS_LOG(ERROR) << "Invalid axis " << GetAxis();
return 1;
}
schema::Format input0_format = input->GetFormat();
for (size_t i = 1; i < inputs.size(); ++i) {
if (inputs[i]->GetFormat() != input0_format) {
MS_LOG(ERROR) << "All inputs should have the same format!";
return 1;
}
auto input_shape_tmp = inputs[i]->shape();
if (input_shape_tmp.size() != input_shape.size()) {
MS_LOG(ERROR) << "All input shape size should be the same!";
return 1;
}
for (size_t j = 0; j < input_shape.size(); ++j) {
if (input_shape_tmp[j] != input_shape[j]) {
MS_LOG(ERROR) << "All input shape should be the same!";
return 1;
}
}
}
output_shape.insert(output_shape.begin() + axis, inputs.size());
outputs[0]->set_shape(output_shape);
outputs[0]->set_data_type(input->data_type());
outputs[0]->SetFormat(input->GetFormat());
return 0;
}
} // namespace mindspore

View File

@ -1,221 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/strided_slice.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
int StridedSlice::GetBeginMask() const { return this->primitive->value.AsStridedSlice()->beginMask; }
int StridedSlice::GetEndMask() const { return this->primitive->value.AsStridedSlice()->endMask; }
int StridedSlice::GetEllipsisMask() const { return this->primitive->value.AsStridedSlice()->ellipsisMask; }
int StridedSlice::GetNewAxisMask() const { return this->primitive->value.AsStridedSlice()->newAxisMask; }
int StridedSlice::GetShrinkAxisMask() const { return this->primitive->value.AsStridedSlice()->shrinkAxisMask; }
std::vector<int> StridedSlice::GetBegin() const { return this->primitive->value.AsStridedSlice()->begin; }
std::vector<int> StridedSlice::GetEnd() const { return this->primitive->value.AsStridedSlice()->end; }
std::vector<int> StridedSlice::GetStride() const { return this->primitive->value.AsStridedSlice()->stride; }
std::vector<int> StridedSlice::GetIsScale() const { return this->primitive->value.AsStridedSlice()->isScale; }
void StridedSlice::SetBeginMask(int begin_mask) { this->primitive->value.AsStridedSlice()->beginMask = begin_mask; }
void StridedSlice::SetEndMask(int end_mask) { this->primitive->value.AsStridedSlice()->endMask = end_mask; }
void StridedSlice::SetEllipsisMask(int ellipsis_mask) {
this->primitive->value.AsStridedSlice()->ellipsisMask = ellipsis_mask;
}
void StridedSlice::SetNewAxisMask(int new_axis_mask) {
this->primitive->value.AsStridedSlice()->newAxisMask = new_axis_mask;
}
void StridedSlice::SetShrinkAxisMask(int shrink_axis_mask) {
this->primitive->value.AsStridedSlice()->shrinkAxisMask = shrink_axis_mask;
}
void StridedSlice::SetBegin(const std::vector<int> &begin) { this->primitive->value.AsStridedSlice()->begin = begin; }
void StridedSlice::SetEnd(const std::vector<int> &end) { this->primitive->value.AsStridedSlice()->end = end; }
void StridedSlice::SetStride(const std::vector<int> &stride) {
this->primitive->value.AsStridedSlice()->stride = stride;
}
void StridedSlice::SetIsScale(const std::vector<int> &is_scale) {
this->primitive->value.AsStridedSlice()->isScale = is_scale;
}
#else
int StridedSlice::GetBeginMask() const { return this->primitive->value_as_StridedSlice()->beginMask(); }
int StridedSlice::GetEndMask() const { return this->primitive->value_as_StridedSlice()->endMask(); }
int StridedSlice::GetEllipsisMask() const { return this->primitive->value_as_StridedSlice()->ellipsisMask(); }
int StridedSlice::GetNewAxisMask() const { return this->primitive->value_as_StridedSlice()->newAxisMask(); }
int StridedSlice::GetShrinkAxisMask() const { return this->primitive->value_as_StridedSlice()->shrinkAxisMask(); }
std::vector<int> StridedSlice::GetBegin() const {
auto fb_vector = this->primitive->value_as_StridedSlice()->begin();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
std::vector<int> StridedSlice::GetEnd() const {
auto fb_vector = this->primitive->value_as_StridedSlice()->end();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
std::vector<int> StridedSlice::GetStride() const {
auto fb_vector = this->primitive->value_as_StridedSlice()->stride();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
std::vector<int> StridedSlice::GetIsScale() const {
auto fb_vector = this->primitive->value_as_StridedSlice()->isScale();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
void StridedSlice::SetBeginMask(int begin_mask) {}
void StridedSlice::SetEndMask(int end_mask) {}
void StridedSlice::SetEllipsisMask(int ellipsis_mask) {}
void StridedSlice::SetNewAxisMask(int new_axis_mask) {}
void StridedSlice::SetShrinkAxisMask(int shrink_axis_mask) {}
void StridedSlice::SetBegin(const std::vector<int> &begin) {}
void StridedSlice::SetEnd(const std::vector<int> &end) {}
void StridedSlice::SetStride(const std::vector<int> &stride) {}
void StridedSlice::SetIsScale(const std::vector<int> &is_scale) {}
#endif
namespace {
constexpr int kStridedSliceOutputNum = 1;
constexpr int kStridedSliceInputNum = 1;
} // namespace
void StridedSlice::ApplyNewAxisMask() {
for (int i = 0; i < new_axis_mask_.size(); i++) {
if (new_axis_mask_.at(i)) {
ndim_ += 1;
in_shape_.insert(in_shape_.begin() + i, 1);
begins_.at(i) = 0;
ends_.at(i) = 1;
strides_.at(i) = 1;
begins_.emplace_back(0);
ends_.emplace_back(in_shape_.at(ndim_ - 1));
strides_.emplace_back(1);
begins_mask_.at(i) = false;
ends_mask_.at(i) = false;
ellipsis_mask_.at(i) = false;
shrink_axis_mask_.at(i) = false;
}
}
}
std::vector<int> StridedSlice::ApplyShrinkMask(std::vector<int> out_shape) {
auto old_out_shape = out_shape;
out_shape.clear();
for (int i = 0; i < shrink_axis_mask_.size(); i++) {
if (shrink_axis_mask_.at(i)) {
ends_.at(i) = begins_.at(i) + 1;
strides_.at(i) = 1;
} else {
out_shape.emplace_back(old_out_shape.at(i));
}
}
for (int i = shrink_axis_mask_.size(); i < old_out_shape.size(); i++) {
out_shape.emplace_back(old_out_shape.at(i));
}
return out_shape;
}
/*only one bit will be used if multiple bits are true.*/
void StridedSlice::ApplyEllipsisMask() {
for (int i = 0; i < ellipsis_mask_.size(); i++) {
if (ellipsis_mask_.at(i)) {
begins_.at(i) = 0;
ends_.at(i) = in_shape_.at(i);
break;
}
}
}
void StridedSlice::ApplyBeginMask() {
for (int i = 0; i < ndim_; i++) {
if (begins_mask_.at(i)) {
begins_.at(i) = 0;
}
}
}
void StridedSlice::ApplyEndMask() {
for (int i = 0; i < ndim_; i++) {
if (ends_mask_.at(i)) {
ends_.at(i) = in_shape_.at(i);
}
}
}
int StridedSlice::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
MS_ASSERT(this->primitive != nullptr);
if (outputs.size() != kStridedSliceOutputNum) {
MS_LOG(ERROR) << "Invalid output size:" << outputs.size();
return 1;
}
if (inputs.size() != kStridedSliceInputNum) {
MS_LOG(ERROR) << "Invalid input size " << inputs.size();
return 1;
}
auto input = inputs.at(0);
MS_ASSERT(input != nullptr);
auto input_shape = input->shape();
std::vector<int> output_shape;
ndim_ = static_cast<int>(GetBegin().size());
MS_ASSERT(ndim_ == static_cast<int>(strided_slice_prim->end()->size()));
MS_ASSERT(ndim_ == static_cast<int>(strided_slice_prim->stride()->size()));
MS_ASSERT(ndim_ == static_cast<int>(input_shape.size()));
for (int i = 0; i < ndim_; i++) {
in_shape_.emplace_back(input_shape.at(i));
begins_.emplace_back((GetBegin())[i]);
ends_.emplace_back((GetEnd())[i]);
strides_.emplace_back((GetStride())[i]);
}
// set all mask to original input shape
begins_mask_.resize(ndim_);
ends_mask_.resize(ndim_);
ellipsis_mask_.resize(ndim_);
new_axis_mask_.resize(ndim_);
shrink_axis_mask_.resize(ndim_);
// convert bit to vector
for (int i = 0; i < ndim_; i++) {
begins_mask_.at(i) = static_cast<uint32_t>(GetBeginMask()) & (1 << i);
ends_mask_.at(i) = static_cast<uint32_t>(GetEndMask()) & (1 << i);
ellipsis_mask_.at(i) = static_cast<uint32_t>(GetEllipsisMask()) & (1 << i);
new_axis_mask_.at(i) = static_cast<uint32_t>(GetNewAxisMask()) & (1 << i);
shrink_axis_mask_.at(i) = static_cast<uint32_t>(GetShrinkAxisMask()) & (1 << i);
}
ApplyNewAxisMask();
ApplyBeginMask();
ApplyEndMask();
ApplyEllipsisMask();
output_shape.clear();
output_shape.resize(in_shape_.size());
for (int i = 0; i < in_shape_.size(); i++) {
if (i < ndim_ && new_axis_mask_.at(i)) {
output_shape.at(i) = 1;
} else {
output_shape.at(i) = (ends_.at(i) - begins_.at(i)) / strides_.at(i);
}
}
output_shape = ApplyShrinkMask(output_shape);
outputs.front()->set_shape(output_shape);
outputs.front()->set_data_type(input->data_type());
outputs[0]->SetFormat(input->GetFormat());
return 0;
}
} // namespace mindspore

View File

@ -1,55 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/tile.h"
#include <algorithm>
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
std::vector<int> Tile::GetMultiples() const { return this->primitive->value.AsTile()->multiples; }
void Tile::SetMultiples(const std::vector<int> &multiples) { this->primitive->value.AsTile()->multiples = multiples; }
#else
std::vector<int> Tile::GetMultiples() const {
auto fb_vector = this->primitive->value_as_Tile()->multiples();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
void Tile::SetMultiples(const std::vector<int> &multiples) {}
#endif
int Tile::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
std::vector<int> out_shape;
std::vector<int> multiples;
std::copy(GetMultiples().begin(), GetMultiples().end(), std::back_inserter(multiples));
for (size_t i = 0; i < input->shape().size(); ++i) {
int tmp = input->shape()[i] * multiples[i];
out_shape.push_back(tmp);
}
output->SetFormat(input->GetFormat());
output->set_shape(out_shape);
output->set_data_type(input->data_type());
return 0;
}
} // namespace mindspore

View File

@ -1,62 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/topk.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
int TopK::GetK() const { return this->primitive->value.AsTopK()->k; }
bool TopK::GetSorted() const { return this->primitive->value.AsTopK()->sorted; }
void TopK::SetK(int k) { this->primitive->value.AsTopK()->k = k; }
void TopK::SetSorted(bool sorted) { this->primitive->value.AsTopK()->sorted = sorted; }
#else
int TopK::GetK() const { return this->primitive->value_as_TopK()->k(); }
bool TopK::GetSorted() const { return this->primitive->value_as_TopK()->sorted(); }
void TopK::SetK(int k) {}
void TopK::SetSorted(bool sorted) {}
#endif
int TopK::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
if (inputs_.size() != kSingleNum || outputs_.size() != kDoubleNum) {
MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size();
return 1;
}
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output0 = outputs_.front();
MS_ASSERT(output0 != nullptr);
auto output1 = outputs_.at(1);
MS_ASSERT(output1 != nullptr);
MS_ASSERT(topk_prim != nullptr);
auto out_shape = input->shape();
out_shape[out_shape.size() - 1] = GetK();
output0->set_shape(out_shape);
output0->set_data_type(input->data_type());
output0->SetFormat(input->GetFormat());
output1->set_shape(out_shape);
output1->set_data_type(kNumberTypeInt32);
output1->SetFormat(input->GetFormat());
return 0;
}
} // namespace mindspore

View File

@ -1,69 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/transpose.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
std::vector<int> Transpose::GetPerm() const { return this->primitive->value.AsTranspose()->perm; }
bool Transpose::GetConjugate() const { return this->primitive->value.AsTranspose()->conjugate; }
void Transpose::SetPerm(const std::vector<int> &perm) { this->primitive->value.AsTranspose()->perm = perm; }
void Transpose::SetConjugate(bool conjugate) { this->primitive->value.AsTranspose()->conjugate = conjugate; }
#else
std::vector<int> Transpose::GetPerm() const {
auto fb_vector = this->primitive->value_as_Transpose()->perm();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
bool Transpose::GetConjugate() const { return this->primitive->value_as_Transpose()->conjugate(); }
void Transpose::SetPerm(const std::vector<int> &perm) {}
void Transpose::SetConjugate(bool conjugate) {}
#endif
int Transpose::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
MS_ASSERT(inputs_.size() == kSingleNum);
MS_ASSERT(outputs_.size() == kSingleNum);
int conjugate = GetConjugate();
if (conjugate) {
MS_LOG(ERROR) << "Transpose conjugate is not support currently";
return 1;
}
std::vector<int> perm;
perm.insert(perm.begin(), GetPerm().begin(), GetPerm().end());
std::vector<int> in_shape = input->shape();
std::vector<int> out_shape;
out_shape.resize(perm.size());
for (int i = 0; i < perm.size(); ++i) {
out_shape[i] = in_shape[perm[i]];
}
output->set_shape(out_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return 0;
}
} // namespace mindspore

View File

@ -1,52 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/unique.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
int Unique::GetOutType() const { return this->primitive->value.AsUnique()->outType; }
void Unique::SetOutType(int out_type) { this->primitive->value.AsUnique()->outType = out_type; }
#else
int Unique::GetOutType() const { return this->primitive->value_as_Unique()->outType(); }
void Unique::SetOutType(int out_type) {}
#endif
int Unique::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
if (inputs_.size() != kSingleNum || outputs_.size() != kDoubleNum) {
MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size();
return 1;
}
auto &input = inputs_.at(0);
MS_ASSERT(input != nullptr);
auto &output0 = outputs_.at(0);
MS_ASSERT(output0 != nullptr);
auto &output1 = outputs_.at(1);
MS_ASSERT(output1 != nullptr);
output0->set_shape(input->shape());
output0->set_data_type(input->data_type());
output1->set_shape(input->shape());
output1->set_data_type(kNumberTypeInt32);
output1->SetFormat(input->GetFormat());
output0->SetFormat(input->GetFormat());
return 0;
}
} // namespace mindspore

View File

@ -1,80 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/unsqueeze.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
std::vector<int> Unsqueeze::GetAxis() const { return this->primitive->value.AsUnsqueeze()->axis; }
void Unsqueeze::SetAxis(const std::vector<int> &axis) { this->primitive->value.AsUnsqueeze()->axis = axis; }
#else
bool predicate(int n) { return n != 1; }
std::vector<int> Unsqueeze::GetAxis() const {
auto fb_vector = this->primitive->value_as_Unsqueeze()->axis();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
void Unsqueeze::SetAxis(const std::vector<int> &axis) {}
#endif
int Unsqueeze::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
if (inputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "input size is invalid";
}
if (outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "output size is invalid";
}
auto dims = GetAxis().data();
auto in_shape = input->shape();
auto in_rank = in_shape.size();
auto dim_rank = GetAxis().size();
std::vector<int> out_shape;
if (dim_rank == 0) {
std::copy_if(in_shape.begin(), in_shape.end(), out_shape.begin(), [](int n) -> bool { return n != 1; });
} else {
auto sz = in_rank + dim_rank;
int in_itr = 0;
int ax_itr = 0;
for (int i = 0; i < sz; i++) {
if (ax_itr < dim_rank && dims[ax_itr] == i) {
out_shape.emplace_back(1);
ax_itr++;
} else if (ax_itr < dim_rank && dims[ax_itr] + sz == i) {
out_shape.emplace_back(1);
ax_itr++;
} else {
if (in_shape[in_itr] > 1) {
out_shape.emplace_back(in_shape[in_itr]);
}
in_itr++;
}
}
}
output->SetFormat(input->GetFormat());
output->set_shape(out_shape);
output->set_data_type(input->data_type());
return 0;
}
} // namespace mindspore

View File

@ -1,59 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/unstack.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
int Unstack::GetNum() const { return this->primitive->value.AsUnstack()->num; }
int Unstack::GetAxis() const { return this->primitive->value.AsUnstack()->axis; }
void Unstack::SetNum(int num) { this->primitive->value.AsUnstack()->num = num; }
void Unstack::SetAxis(int axis) { this->primitive->value.AsUnstack()->axis = axis; }
#else
int Unstack::GetNum() const { return this->primitive->value_as_Unstack()->num(); }
int Unstack::GetAxis() const { return this->primitive->value_as_Unstack()->axis(); }
void Unstack::SetNum(int num) {}
void Unstack::SetAxis(int axis) {}
#endif
int Unstack::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
auto input = inputs.at(0);
MS_ASSERT(input != nullptr);
auto input_shape = input->shape();
int axis = GetAxis() < 0 ? GetAxis() + input_shape.size() : GetAxis();
if (axis < 0 || axis >= input_shape.size()) {
MS_LOG(ERROR) << "Invalid axis " << GetAxis();
return 1;
}
std::vector<int> output_shape;
for (size_t i = 0; i < input_shape.size(); ++i) {
if (i != axis) {
output_shape.push_back(input_shape.at(i));
}
}
for (auto &out : outputs) {
MS_ASSERT(out != nullptr);
out->set_shape(output_shape);
out->set_data_type(input->data_type());
out->SetFormat(input->GetFormat());
}
return 0;
}
} // namespace mindspore

View File

@ -1,93 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/where.h"
namespace mindspore {
#ifdef PRIMITIVE_WRITEABLE
std::vector<bool> Where::GetCondition() const { return this->primitive->value.AsWhere()->condition; }
void Where::SetCondition(const std::vector<bool> &condition) {
this->primitive->value.AsWhere()->condition = condition;
}
#else
std::vector<bool> Where::GetCondition() const {
auto fb_vector = this->primitive->value_as_Where()->condition();
return std::vector<bool>(fb_vector->begin(), fb_vector->end());
}
void Where::SetCondition(const std::vector<bool> &condition) {}
#endif
int Where::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "where input or output number invalid, Input size:" << inputs_.size()
<< ", output size: " << outputs_.size();
return 1;
}
if (inputs_.size() < 3) {
MS_LOG(ERROR) << "Input shape tensors should b";
return 1;
}
auto input0 = inputs_.at(0);
auto input1 = inputs_.at(1);
auto input2 = inputs_.at(2);
int num = input0->ElementsNum();
int num1 = input1->ElementsNum();
int num2 = input2->ElementsNum();
int nummax = num > num1 ? num : (num1 > num2 ? num1 : num2);
auto shape_tmp = inputs_.at(0)->shape();
auto shape_tmp1 = inputs_.at(1)->shape();
auto shape_tmp2 = inputs_.at(2)->shape();
int axisout = 0;
int temp = 0;
for (int j = 0; j < shape_tmp.size(); j++) {
if (shape_tmp[j] == shape_tmp1[j] && shape_tmp[j] != shape_tmp2[j]) {
axisout = j;
break;
}
if (shape_tmp[j] == shape_tmp2[j] && shape_tmp[j] != shape_tmp1[j]) {
axisout = j;
break;
}
if (shape_tmp1[j] == shape_tmp2[j] && shape_tmp[j] != shape_tmp1[j]) {
axisout = j;
break;
}
temp += 1;
if (temp == shape_tmp.size()) {
outputs_[0]->set_shape(shape_tmp);
output->set_data_type(input->data_type());
return 0;
}
}
auto output_shape = shape_tmp;
output_shape[axisout] = nummax;
outputs_[0]->set_shape(output_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return 0;
}
} // namespace mindspore

View File

@ -21,6 +21,7 @@
#include <vector>
#include <memory>
#include "schema/model_generated.h"
#include "src/ops/primitive_c.h"
namespace mindspore {
#define MS_API __attribute__((visibility("default")))
@ -34,7 +35,7 @@ class ModelImpl;
/// \brief Primitive defined as prototype of operator.
///
/// \note List public class and interface for reference.
class Primitive;
class PrimitiveC;
/// \brief Model defined model in MindSpore Lite for managing graph.
class MS_API Model {
@ -60,7 +61,7 @@ class MS_API Model {
/// \param[in] name Define name of primitive to be returned.
///
/// \return the pointer of MindSpore Lite Primitive.
lite::Primitive *GetOp(const std::string &name) const;
PrimitiveC *GetOp(const std::string &name) const;
/// \brief Get graph defined in flatbuffers.
///
@ -97,7 +98,7 @@ class MS_API ModelBuilder {
/// \param[in] inputs Define input edge of primitive to be added.
///
/// \return ID of the added primitive.
virtual std::string AddOp(const lite::Primitive &op, const std::vector<OutEdge> &inputs) = 0;
virtual std::string AddOp(const PrimitiveC &op, const std::vector<OutEdge> &inputs) = 0;
/// \brief Finish constructing the model.
///

View File

@ -1,64 +1,64 @@
set(LITE_SRC
${CMAKE_CURRENT_SOURCE_DIR}/common/graph_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/common/ms_tensor_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/allocator.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/runtime_api.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/thread_pool.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/workspace_pool.cc
${CMAKE_CURRENT_SOURCE_DIR}/ir/tensor.cc
${CMAKE_CURRENT_SOURCE_DIR}/context.cc
${CMAKE_CURRENT_SOURCE_DIR}/executor.cc
${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc
${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc
${CMAKE_CURRENT_SOURCE_DIR}/populate_parameter.cc
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc
)
${CMAKE_CURRENT_SOURCE_DIR}/common/graph_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/common/ms_tensor_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/allocator.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/runtime_api.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/thread_pool.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/workspace_pool.cc
${CMAKE_CURRENT_SOURCE_DIR}/ir/tensor.cc
${CMAKE_CURRENT_SOURCE_DIR}/context.cc
${CMAKE_CURRENT_SOURCE_DIR}/executor.cc
${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc
${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc
${CMAKE_CURRENT_SOURCE_DIR}/populate_parameter.cc
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc
)
if (SUPPORT_GPU)
list(APPEND LITE_SRC ${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/subgraph_opencl_kernel.cc)
list(APPEND LITE_SRC ${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/utils.cc)
endif()
endif ()
set(LITE_SRC
${LITE_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc
${CMAKE_CURRENT_SOURCE_DIR}/model.cc
)
${LITE_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc
${CMAKE_CURRENT_SOURCE_DIR}/model.cc
)
if (SUPPORT_GPU)
set(LITE_SRC
${LITE_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_executor.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_allocator.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_runtime.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_wrapper.cc
)
set(LITE_SRC
${LITE_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_executor.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_allocator.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_runtime.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_wrapper.cc
)
endif ()
set(ANF_SRC
${ANF_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/ir/meta_tensor_extends.cc
)
${ANF_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/ir/meta_tensor_extends.cc
)
add_library(mindspore-lite SHARED ${LITE_SRC} ${ANF_SRC})
target_link_libraries(mindspore-lite
cpu_kernel_mid_
ops_mid_
)
cpu_kernel_mid_
c_ops_mid
)
add_subdirectory(runtime/kernel/arm)
if (PLATFORM_ARM32 OR PLATFORM_ARM64)
target_link_libraries(mindspore-lite log)
endif()
target_link_libraries(mindspore-lite log)
endif ()
if (BUILD_MINDDATA)
target_link_libraries(mindspore-lite minddata-eager minddata-lite)
target_link_libraries(mindspore-lite minddata-eager minddata-lite)
endif ()
add_subdirectory(ops)
if("${CMAKE_BUILD_TYPE}" STREQUAL "Release" AND (PLATFORM_ARM64 OR PLATFORM_ARM32))
add_custom_command(TARGET mindspore-lite POST_BUILD
COMMAND ${ANDROID_NDK}/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/aarch64-linux-android/bin/strip
${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.so)
endif()
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Release" AND (PLATFORM_ARM64 OR PLATFORM_ARM32))
add_custom_command(TARGET mindspore-lite POST_BUILD
COMMAND ${ANDROID_NDK}/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/aarch64-linux-android/bin/strip
${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.so)
endif ()

View File

@ -49,7 +49,13 @@ int Executor::Run(std::vector<tensor::Tensor *> &in_tensors, std::vector<tensor:
MS_LOG(ERROR) << "run kernel failed, name: " << kernel->name();
return ret;
}
MS_LOG(INFO) << "out_tensors";
auto tensors = kernel->out_tensors();
MS_LOG(INFO) << kernel->name();
for (int i = 0; i < tensors.size(); ++i) {
auto tensor = tensors[i];
MS_LOG(INFO) << tensor->ToString();
}
if (after != nullptr) {
if (!after(PackToMSTensors(kernel->in_tensors()), PackToMSTensors(kernel->out_tensors()),
{kernel->name(), kernel->type_str()})) {

View File

@ -1,19 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ir/primitive_value.h"

View File

@ -1,47 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVE_H_
#define MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVE_H_
#include "ir/value.h"
#include "src/ops/ops.h"
namespace mindspore::lite {
class PrimitiveValue : public Value {
public:
explicit PrimitiveValue(const lite::Primitive *prim) : primitive(prim) {}
const lite::Primitive *GetPrimitive() const {
return this->primitive;
}
MS_DECLARE_PARENT(PrimitiveValue, Value)
bool operator==(const Value &rhs) const override {
if (rhs.isa<PrimitiveValue>()) {
auto other_prim = static_cast<const PrimitiveValue &>(rhs);
return *this == other_prim;
} else {
return false;
}
}
protected:
const lite::Primitive *primitive = nullptr;
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVE_H_

View File

@ -124,13 +124,13 @@ const kernel::KernelCreator *KernelRegistry::GetCreatorArrays() { return creator
kernel::LiteKernel *KernelRegistry::GetKernel(const std::vector<tensor::Tensor *> &in_tensors,
const std::vector<tensor::Tensor *> &out_tensors,
const lite::Primitive *primitive, const Context *ctx,
const PrimitiveC *primitive, const Context *ctx,
const kernel::KernelKey &key) {
MS_EXCEPTION_IF_NULL(primitive);
MS_EXCEPTION_IF_NULL(ctx);
auto parameter = kernel::PopulateParameter(primitive);
if (parameter == nullptr) {
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(primitive->Type());
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << primitive->Type();
return nullptr;
}
auto creator = GetCreator(key);

View File

@ -40,7 +40,7 @@ class KernelRegistry {
kernel::KernelCreator creator);
bool Merge(const std::unordered_map<kernel::KernelKey, kernel::KernelCreator> &newCreators);
kernel::LiteKernel *GetKernel(const std::vector<tensor::Tensor *> &in_tensors,
const std::vector<tensor::Tensor *> &out_tensors, const lite::Primitive *primitive,
const std::vector<tensor::Tensor *> &out_tensors, const PrimitiveC *primitive,
const Context *ctx, const kernel::KernelKey &key);
protected:

View File

@ -24,8 +24,8 @@
#include "src/runtime/kernel/arm/nnacl/op_base.h"
#include "include/context.h"
#include "src/ir/tensor.h"
#include "src/ops/ops.h"
#include "include/errorcode.h"
#include "src/ops/primitive_c.h"
#ifdef ENABLE_FP16
using FLOAT_t = float16_t;
@ -59,7 +59,7 @@ class LiteKernel {
LiteKernel() = default;
LiteKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &in_tensors,
const std::vector<lite::tensor::Tensor *> &out_tensors, const lite::Context *ctx,
const lite::Primitive *primitive)
const mindspore::lite::PrimitiveC *primitive)
: op_parameter_(parameter),
in_tensors_(in_tensors),
out_tensors_(out_tensors),
@ -81,7 +81,7 @@ class LiteKernel {
virtual int Prepare() {
if (!InferShapeDone()) {
(const_cast<lite::Primitive *>(primitive_))->InferShape(in_tensors_, out_tensors_);
(const_cast<mindspore::lite::PrimitiveC *>(primitive_))->InferShape(in_tensors_, out_tensors_);
if (need_reinit_) {
Init();
}
@ -154,7 +154,7 @@ class LiteKernel {
void set_need_reinit() { need_reinit_ = true; }
const lite::Primitive *GetPrimitive() const { return primitive_; }
const mindspore::lite::PrimitiveC *GetPrimitive() const { return primitive_; }
protected:
bool InferShapeDone() { return !(primitive_ != nullptr && !primitive_->GetInferFlag()) && true; }
@ -162,7 +162,7 @@ class LiteKernel {
KernelKey desc_;
std::string name_;
OpParameter *op_parameter_ = nullptr;
const lite::Primitive *primitive_ = nullptr;
const mindspore::lite::PrimitiveC *primitive_ = nullptr;
const lite::Context *context_ = nullptr;
// tensor will free in ~lite_session()
std::vector<lite::tensor::Tensor *> in_tensors_;
@ -181,7 +181,7 @@ class SubGraphKernel : public LiteKernel {
const std::vector<kernel::LiteKernel *> &in_kernels,
const std::vector<kernel::LiteKernel *> &out_kernels,
const std::vector<kernel::LiteKernel *> &nodes, const lite::Context *ctx,
const lite::Primitive *primitive)
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(nullptr, inputs, outputs, ctx, primitive), nodes_(nodes) {
in_kernels_ = in_kernels;
out_kernels_ = out_kernels;
@ -198,7 +198,8 @@ class SubGraphKernel : public LiteKernel {
typedef LiteKernel *(*KernelCreator)(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, OpParameter *parameter,
const lite::Context *ctx, const KernelKey &desc, const lite::Primitive *primitive);
const lite::Context *ctx, const KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive);
class LiteKernelUtil {
public:

View File

@ -14,9 +14,101 @@
* limitations under the License.
*/
#include "src/ops/unique.h"
#include "src/ops/space_to_batch.h"
#include "src/ops/conv2d.h"
#include "src/ops/roi_pooling.h"
#include "src/ops/topk.h"
#include "src/ops/broadcast_to.h"
#include "src/ops/unsqueeze.h"
#include "src/ops/unstack.h"
#include "src/ops/depth_to_space.h"
#include "src/ops/batch_to_space.h"
#include "src/ops/prior_box.h"
#include "src/ops/lstm.h"
#include "src/ops/softmax.h"
#include "src/ops/activation.h"
#include "src/ops/deconv2d.h"
#include "src/ops/reduce.h"
#include "src/ops/pooling.h"
#include "src/ops/fused_batchnorm.h"
#include "src/ops/batch_norm.h"
#include "src/ops/power.h"
#include "src/ops/range.h"
#include "src/ops/add.h"
#include "src/ops/sub.h"
#include "src/ops/div.h"
#include "src/ops/bias_add.h"
#include "src/ops/expand_dims.h"
#include "src/ops/full_connection.h"
#include "src/ops/shape.h"
#include "src/ops/elu.h"
#include "src/ops/embedding_lookup.h"
#include "src/ops/quant_dtype_cast.h"
#include "src/ops/matmul.h"
#include "src/ops/resize.h"
#include "src/ops/tile.h"
#include "src/ops/one_hot.h"
#include "src/ops/space_to_depth.h"
#include "src/ops/split.h"
#include "src/ops/argmax.h"
#include "src/ops/argmin.h"
#include "src/ops/cast.h"
#include "src/ops/reshape.h"
#include "src/ops/scale.h"
#include "src/ops/concat.h"
#include "src/ops/nchw2nhwc.h"
#include "src/ops/slice.h"
#include "src/ops/squeeze.h"
#include "src/ops/flatten.h"
#include "src/ops/mean.h"
#include "src/ops/nhwc2nchw.h"
#include "src/ops/stack.h"
#include "src/ops/crop.h"
#include "src/ops/addn.h"
#include "src/ops/gather.h"
#include "src/ops/gather_nd.h"
#include "src/ops/local_response_normalization.h"
#include "src/ops/pad.h"
#include "src/ops/prelu.h"
#include "src/ops/caffe_p_relu.h"
#include "src/ops/reverse_sequence.h"
#include "src/ops/dedepthwise_conv2d.h"
#include "src/ops/depthwise_conv2d.h"
#include "src/ops/mul.h"
#include "src/ops/eltwise.h"
#include "src/ops/fill.h"
#include "src/ops/transpose.h"
#include "src/ops/log.h"
#include "src/ops/abs.h"
#include "src/ops/sin.h"
#include "src/ops/cos.h"
#include "src/ops/sqrt.h"
#include "src/ops/square.h"
#include "src/ops/exp.h"
#include "src/ops/rsqrt.h"
#include "src/ops/maximum.h"
#include "src/ops/minimum.h"
#include "src/ops/strided_slice.h"
#include "src/ops/reverse.h"
#include "src/ops/logical_and.h"
#include "src/ops/logical_or.h"
#include "src/ops/logical_not.h"
#include "src/ops/floor_div.h"
#include "src/ops/floor_mod.h"
#include "src/ops/equal.h"
#include "src/ops/not_equal.h"
#include "src/ops/less.h"
#include "src/ops/less_equal.h"
#include "src/ops/greater_equal.h"
#include "src/ops/greater.h"
#include "src/ops/floor.h"
#include "src/ops/squared_difference.h"
#include "src/ops/ceil.h"
#include "src/ops/round.h"
#include "src/ops/primitive_c.h"
#include "include/model.h"
#include "utils/log_adapter.h"
#include "src/ops/ops.h"
namespace mindspore::lite {
@ -28,19 +120,19 @@ class ModelImpl {
meta_graph_ = schema::GetMetaGraph(model_buf);
}
virtual ~ModelImpl();
lite::Primitive *GetOp(const std::string &name) const;
PrimitiveC *GetOp(const std::string &name) const;
const schema::MetaGraph *meta_graph() const;
void FreeMetaGraph();
int BuildOps();
protected:
lite::Primitive *CopyPrimitive(const schema::Primitive *src_prim);
PrimitiveC *CopyPrimitive(const schema::Primitive *src_prim);
protected:
const char *model_buf_;
size_t buf_size_;
const schema::MetaGraph *meta_graph_ = nullptr;
std::map<std::string, lite::Primitive *> ops_;
std::map<std::string, PrimitiveC *> ops_;
};
ModelImpl *ModelImpl::Import(const char *model_buf, size_t size) {
@ -72,7 +164,7 @@ ModelImpl *ModelImpl::Import(const char *model_buf, size_t size) {
return model;
}
lite::Primitive *ModelImpl::GetOp(const std::string &name) const {
PrimitiveC *ModelImpl::GetOp(const std::string &name) const {
auto iter = ops_.find(name);
if (iter == ops_.end()) {
return nullptr;
@ -96,178 +188,178 @@ void ModelImpl::FreeMetaGraph() {
const schema::MetaGraph *ModelImpl::meta_graph() const { return this->meta_graph_; }
lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *src_prim) {
PrimitiveC *ModelImpl::CopyPrimitive(const schema::Primitive *src_prim) {
MS_EXCEPTION_IF_NULL(src_prim);
auto op_type = src_prim->value_type();
switch (op_type) {
case schema::PrimitiveType_SoftMax:
return new lite::SoftMax(const_cast<schema::Primitive *>(src_prim));
return new SoftMax(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Activation:
return new lite::Activation(const_cast<schema::Primitive *>(src_prim));
return new Activation(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Conv2D:
return new lite::Conv2D(const_cast<schema::Primitive *>(src_prim));
return new Conv2D(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_DeConv2D:
return new lite::DeConv2D(const_cast<schema::Primitive *>(src_prim));
return new DeConv2D(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Reduce:
return new lite::Reduce(const_cast<schema::Primitive *>(src_prim));
return new Reduce(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Pooling:
return new lite::Pooling(const_cast<schema::Primitive *>(src_prim));
return new Pooling(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_DepthwiseConv2D:
return new lite::DepthwiseConv2D(const_cast<schema::Primitive *>(src_prim));
return new DepthwiseConv2D(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_FusedBatchNorm:
return new lite::FusedBatchNorm(const_cast<schema::Primitive *>(src_prim));
return new FusedBatchNorm(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_BatchNorm:
return new lite::BatchNorm(const_cast<schema::Primitive *>(src_prim));
return new BatchNorm(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_FullConnection:
return new lite::FullConnection(const_cast<schema::Primitive *>(src_prim));
return new FullConnection(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Power:
return new lite::Power(const_cast<schema::Primitive *>(src_prim));
return new Power(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Range:
return new lite::Range(const_cast<schema::Primitive *>(src_prim));
return new Range(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Mul:
return new lite::Mul(const_cast<schema::Primitive *>(src_prim));
return new Mul(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Add:
return new lite::Add(const_cast<schema::Primitive *>(src_prim));
return new Add(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Sub:
return new lite::Sub(const_cast<schema::Primitive *>(src_prim));
return new Sub(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Div:
return new lite::Div(const_cast<schema::Primitive *>(src_prim));
return new Div(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_BiasAdd:
return new lite::BiasAdd(const_cast<schema::Primitive *>(src_prim));
return new BiasAdd(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_ExpandDims:
return new lite::ExpandDims(const_cast<schema::Primitive *>(src_prim));
return new ExpandDims(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_ArgMax:
return new lite::ArgMax(const_cast<schema::Primitive *>(src_prim));
return new ArgMax(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_ArgMin:
return new lite::ArgMin(const_cast<schema::Primitive *>(src_prim));
return new ArgMin(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Cast:
return new lite::Cast(const_cast<schema::Primitive *>(src_prim));
return new Cast(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Reshape:
return new lite::Reshape(const_cast<schema::Primitive *>(src_prim));
return new Reshape(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Scale:
return new lite::Scale(const_cast<schema::Primitive *>(src_prim));
return new Scale(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Eltwise:
return new lite::Eltwise(const_cast<schema::Primitive *>(src_prim));
return new Eltwise(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Concat:
return new lite::Concat(const_cast<schema::Primitive *>(src_prim));
return new Concat(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Fill:
return new lite::Fill(const_cast<schema::Primitive *>(src_prim));
return new Fill(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Transpose:
return new lite::Transpose(const_cast<schema::Primitive *>(src_prim));
return new Transpose(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Slice:
return new lite::Slice(const_cast<schema::Primitive *>(src_prim));
return new SliceOp(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Squeeze:
return new lite::Squeeze(const_cast<schema::Primitive *>(src_prim));
return new Squeeze(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Nchw2Nhwc:
return new lite::Nchw2Nhwc(const_cast<schema::Primitive *>(src_prim));
return new Nchw2Nhwc(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Nhwc2Nchw:
return new lite::Nhwc2Nchw(const_cast<schema::Primitive *>(src_prim));
return new Nhwc2Nchw(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Flatten:
return new lite::Flatten(const_cast<schema::Primitive *>(src_prim));
return new Flatten(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Mean:
return new lite::Mean(const_cast<schema::Primitive *>(src_prim));
return new Mean(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Stack:
return new lite::Stack(const_cast<schema::Primitive *>(src_prim));
return new Stack(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Crop:
return new lite::Crop(const_cast<schema::Primitive *>(src_prim));
return new Crop(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_SquaredDifference:
return new lite::SquaredDifference(const_cast<schema::Primitive *>(src_prim));
return new SquaredDifference(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_AddN:
return new lite::AddN(const_cast<schema::Primitive *>(src_prim));
return new AddN(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Abs:
return new lite::Abs(const_cast<schema::Primitive *>(src_prim));
return new Abs(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Sin:
return new lite::Sin(const_cast<schema::Primitive *>(src_prim));
return new Sin(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Cos:
return new lite::Cos(const_cast<schema::Primitive *>(src_prim));
return new Cos(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Log:
return new lite::Log(const_cast<schema::Primitive *>(src_prim));
return new Log(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Sqrt:
return new lite::Sqrt(const_cast<schema::Primitive *>(src_prim));
return new Sqrt(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Rsqrt:
return new lite::Rsqrt(const_cast<schema::Primitive *>(src_prim));
return new Rsqrt(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Square:
return new lite::Square(const_cast<schema::Primitive *>(src_prim));
return new Square(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Exp:
return new lite::Exp(const_cast<schema::Primitive *>(src_prim));
return new Exp(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Gather:
return new lite::Gather(const_cast<schema::Primitive *>(src_prim));
return new Gather(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_GatherNd:
return new lite::GatherNd(const_cast<schema::Primitive *>(src_prim));
return new GatherNd(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_LocalResponseNormalization:
return new lite::LocalResponseNormalization(const_cast<schema::Primitive *>(src_prim));
return new LocalResponseNormalization(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Maximum:
return new lite::Maximum(const_cast<schema::Primitive *>(src_prim));
return new Maximum(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Minimum:
return new lite::Minimum(const_cast<schema::Primitive *>(src_prim));
return new Minimum(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Pad:
return new lite::Pad(const_cast<schema::Primitive *>(src_prim));
return new Pad(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_StridedSlice:
return new lite::StridedSlice(const_cast<schema::Primitive *>(src_prim));
return new StridedSlice(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Prelu:
return new lite::Prelu(const_cast<schema::Primitive *>(src_prim));
return new Prelu(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_CaffePReLU:
return new lite::CaffePReLU(const_cast<schema::Primitive *>(src_prim));
return new CaffePReLU(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Round:
return new lite::Round(const_cast<schema::Primitive *>(src_prim));
return new Round(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Reverse:
return new lite::Reverse(const_cast<schema::Primitive *>(src_prim));
return new Reverse(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_ReverseSequence:
return new lite::ReverseSequence(const_cast<schema::Primitive *>(src_prim));
return new ReverseSequence(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_LogicalAnd:
return new lite::LogicalAnd(const_cast<schema::Primitive *>(src_prim));
return new LogicalAnd(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_LogicalOr:
return new lite::LogicalOr(const_cast<schema::Primitive *>(src_prim));
return new LogicalOr(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_LogicalNot:
return new lite::LogicalNot(const_cast<schema::Primitive *>(src_prim));
return new LogicalNot(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_FloorDiv:
return new lite::FloorDiv(const_cast<schema::Primitive *>(src_prim));
return new FloorDiv(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_FloorMod:
return new lite::FloorMod(const_cast<schema::Primitive *>(src_prim));
return new FloorMod(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Equal:
return new lite::Equal(const_cast<schema::Primitive *>(src_prim));
return new Equal(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_NotEqual:
return new lite::NotEqual(const_cast<schema::Primitive *>(src_prim));
return new NotEqual(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Less:
return new lite::Less(const_cast<schema::Primitive *>(src_prim));
return new Less(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_LessEqual:
return new lite::LessEqual(const_cast<schema::Primitive *>(src_prim));
return new LessEqual(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Greater:
return new lite::Greater(const_cast<schema::Primitive *>(src_prim));
return new Greater(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_GreaterEqual:
return new lite::GreaterEqual(const_cast<schema::Primitive *>(src_prim));
return new GreaterEqual(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Floor:
return new lite::Floor(const_cast<schema::Primitive *>(src_prim));
return new Floor(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Ceil:
return new lite::Ceil(const_cast<schema::Primitive *>(src_prim));
return new Ceil(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Split:
return new lite::Split(const_cast<schema::Primitive *>(src_prim));
return new Split(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_OneHot:
return new lite::OneHot(const_cast<schema::Primitive *>(src_prim));
return new OneHot(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_SpaceToDepth:
return new lite::SpaceToDepth(const_cast<schema::Primitive *>(src_prim));
return new SpaceToDepth(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Tile:
return new lite::Tile(const_cast<schema::Primitive *>(src_prim));
return new Tile(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Resize:
return new lite::Resize(const_cast<schema::Primitive *>(src_prim));
return new Resize(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Unstack:
return new lite::Unstack(const_cast<schema::Primitive *>(src_prim));
return new Unstack(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Unique:
return new lite::Unique(const_cast<schema::Primitive *>(src_prim));
return new Unique(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_TopK:
return new lite::TopK(const_cast<schema::Primitive *>(src_prim));
return new TopK(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_MatMul:
return new lite::MatMul(const_cast<schema::Primitive *>(src_prim));
return new MatMul(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_QuantDTypeCast:
return new lite::QuantDTypeCast(const_cast<schema::Primitive *>(src_prim));
return new QuantDTypeCast(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_EmbeddingLookup:
return new lite::EmbeddingLookup(const_cast<schema::Primitive *>(src_prim));
return new EmbeddingLookup(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Elu:
return new lite::Elu(const_cast<schema::Primitive *>(src_prim));
return new Elu(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_DeDepthwiseConv2D:
return new lite::DeconvDepthwiseConv2D(const_cast<schema::Primitive *>(src_prim));
return new DeDepthwiseConv2D(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Shape:
return new lite::Shape(const_cast<schema::Primitive *>(src_prim));
return new Shape(const_cast<schema::Primitive *>(src_prim));
default:
break;
}
@ -334,9 +426,9 @@ Model *Model::Import(const char *model_buf, size_t size) {
Model::~Model() { delete (this->model_impl_); }
lite::Primitive *Model::GetOp(const std::string &name) const {
mindspore::lite::PrimitiveC *Model::GetOp(const std::string &name) const {
MS_EXCEPTION_IF_NULL(model_impl_);
return const_cast<Primitive *>(model_impl_->GetOp(name));
return const_cast<PrimitiveC *>(model_impl_->GetOp(name));
}
void Model::FreeMetaGraph() {

View File

@ -1,3 +1,3 @@
file(GLOB_RECURSE OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cc)
file(GLOB_RECURSE C_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cc)
add_library(ops_mid_ OBJECT ${OPS_SRC})
add_library(c_ops_mid OBJECT ${C_OPS_SRC})

View File

@ -1,7 +1,7 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the License);
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
@ -18,7 +18,7 @@
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "c_ops/arithmetic_self.h"
#include "src/ops/arithmetic_self.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
@ -29,14 +29,11 @@
#define LITE_MINDSPORE_LITE_C_OPS_ABS_H_
namespace mindspore {
namespace lite {
class Abs : public ArithmeticSelf {
public:
#ifdef PRIMITIVE_WRITEABLE
explicit Abs(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
#else
explicit Abs(schema::Primitive *primitive) : ArithmeticSelf(primitive) {}
#endif
explicit Abs(OriginPrimitive *primitive) : ArithmeticSelf(primitive) {}
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_ABS_H_

View File

@ -14,9 +14,10 @@
* limitations under the License.
*/
#include "c_ops/activation.h"
#include "src/ops/activation.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Activation::GetType() const { return this->primitive->value.AsActivation()->type; }
float Activation::GetAlpha() const { return this->primitive->value.AsActivation()->alpha; }
@ -32,4 +33,5 @@ float Activation::GetAlpha() const { return this->primitive->value_as_Activation
void Activation::SetType(int type) {}
void Activation::SetAlpha(float alpha) {}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -18,7 +18,7 @@
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "mindspore/lite/c_ops/primitive_c.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
@ -29,18 +29,16 @@
#define LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_H_
namespace mindspore {
namespace lite {
class Activation : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
explicit Activation(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
explicit Activation(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
explicit Activation(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
int GetType() const;
float GetAlpha() const;
void SetType(int type);
void SetAlpha(float alpha);
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_H_

View File

@ -14,9 +14,10 @@
* limitations under the License.
*/
#include "c_ops/activation_grad.h"
#include "src/ops/activation_grad.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int ActivationGrad::GetType() const { return this->primitive->value.AsActivationGrad()->type; }
@ -30,4 +31,5 @@ int ActivationGrad::GetType() const { return this->primitive->value_as_Activatio
void ActivationGrad::SetType(int type) {}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -18,7 +18,7 @@
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "mindspore/lite/c_ops/primitive_c.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
@ -29,16 +29,14 @@
#define LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_GRAD_H_
namespace mindspore {
namespace lite {
class ActivationGrad : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
explicit ActivationGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
explicit ActivationGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
explicit ActivationGrad(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
int GetType() const;
void SetType(int type);
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_GRAD_H_

View File

@ -14,9 +14,10 @@
* limitations under the License.
*/
#include "c_ops/add.h"
#include "src/ops/add.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Add::GetActivationType() const { return this->primitive->value.AsAdd()->activationType; }
@ -30,4 +31,5 @@ int Add::GetActivationType() const { return this->primitive->value_as_Add()->act
void Add::SetActivationType(int activation_type) {}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -18,7 +18,7 @@
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "c_ops/arithmetic.h"
#include "src/ops/arithmetic.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
@ -28,16 +28,15 @@
#define LITE_MINDSPORE_LITE_C_OPS_ADD_H_
namespace mindspore {
namespace lite {
class Add : public Arithmetic {
public:
#ifdef PRIMITIVE_WRITEABLE
explicit Add(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
#else
explicit Add(schema::Primitive *primitive) : Arithmetic(primitive) {}
#endif
explicit Add(OriginPrimitive *primitive) : Arithmetic(primitive) {}
int GetActivationType() const;
void SetActivationType(int activation_type);
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_ADD_H_

View File

@ -14,12 +14,22 @@
* limitations under the License.
*/
#include "src/ops/ops.h"
#include "include/errorcode.h"
#include "utils/log_adapter.h"
#include "src/ir/tensor.h"
#include "src/ops/addn.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int AddN::GetN() const { return this->primitive->value.AsAddN()->N; }
void AddN::SetN(int n) { this->primitive->value.AsAddN()->N = n; }
#else
int AddN::GetN() const { return this->primitive->value_as_AddN()->N(); }
void AddN::SetN(int n) {}
#endif
namespace mindspore::lite {
namespace {
constexpr int kLeastInputNum = 2;
}
@ -48,5 +58,5 @@ int AddN::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::T
output->set_data_type(input->data_type());
return RET_OK;
}
} // namespace mindspore::lite
} // namespace lite
} // namespace mindspore

View File

@ -18,7 +18,7 @@
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "mindspore/lite/c_ops/primitive_c.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
@ -29,17 +29,15 @@
#define LITE_MINDSPORE_LITE_C_OPS_ADD_N_H_
namespace mindspore {
namespace lite {
class AddN : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
explicit AddN(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
explicit AddN(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
explicit AddN(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetN() const;
void SetN(int n);
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_ADD_N_H_

View File

@ -14,12 +14,38 @@
* limitations under the License.
*/
#include "src/ops/ops.h"
#include "include/errorcode.h"
#include "utils/log_adapter.h"
#include "src/ir/tensor.h"
#include "src/ops/argmax.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int ArgMax::GetAxis() const { return this->primitive->value.AsArgMax()->axis; }
bool ArgMax::GetOutMaxValue() const { return this->primitive->value.AsArgMax()->outMaxValue; }
int ArgMax::GetTopK() const { return this->primitive->value.AsArgMax()->topK; }
bool ArgMax::GetKeepDims() const { return this->primitive->value.AsArgMax()->keepDims; }
int ArgMax::GetAxisType() const { return this->primitive->value.AsArgMax()->axisType; }
void ArgMax::SetAxis(int axis) { this->primitive->value.AsArgMax()->axis = axis; }
void ArgMax::SetOutMaxValue(bool out_max_value) { this->primitive->value.AsArgMax()->outMaxValue = out_max_value; }
void ArgMax::SetTopK(int top_k) { this->primitive->value.AsArgMax()->topK = top_k; }
void ArgMax::SetKeepDims(bool keep_dims) { this->primitive->value.AsArgMax()->keepDims = keep_dims; }
void ArgMax::SetAxisType(int axis_type) { this->primitive->value.AsArgMax()->axisType = axis_type; }
#else
int ArgMax::GetAxis() const { return this->primitive->value_as_ArgMax()->axis(); }
bool ArgMax::GetOutMaxValue() const { return this->primitive->value_as_ArgMax()->outMaxValue(); }
int ArgMax::GetTopK() const { return this->primitive->value_as_ArgMax()->topK(); }
bool ArgMax::GetKeepDims() const { return this->primitive->value_as_ArgMax()->keepDims(); }
int ArgMax::GetAxisType() const { return this->primitive->value_as_ArgMax()->axisType(); }
void ArgMax::SetAxis(int axis) {}
void ArgMax::SetOutMaxValue(bool out_max_value) {}
void ArgMax::SetTopK(int top_k) {}
void ArgMax::SetKeepDims(bool keep_dims) {}
void ArgMax::SetAxisType(int axis_type) {}
#endif
namespace mindspore::lite {
int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front();
@ -30,7 +56,6 @@ int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
MS_LOG(ERROR) << "tensor number is error.";
}
auto argmax_prim = this->primitive->value_as_ArgMax();
std::vector<int> output_shape(input->shape());
auto input_shape_size = input->shape().size();
int axis = argmax_prim->axis() < 0 ? argmax_prim->axis() + input_shape_size : argmax_prim->axis();
@ -43,11 +68,10 @@ int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
} else {
output_shape[axis] = argmax_prim->topK();
}
output->SetFormat(input->GetFormat());
output->set_shape(output_shape);
output->set_data_type(input->data_type());
return RET_OK;
}
} // namespace mindspore::lite
} // namespace lite
} // namespace mindspore

View File

@ -18,7 +18,7 @@
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "mindspore/lite/c_ops/primitive_c.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
@ -29,13 +29,11 @@
#define LITE_MINDSPORE_LITE_C_OPS_ARG_MAX_H_
namespace mindspore {
namespace lite {
class ArgMax : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
explicit ArgMax(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
explicit ArgMax(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
explicit ArgMax(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetAxis() const;
bool GetOutMaxValue() const;
@ -48,6 +46,7 @@ class ArgMax : public PrimitiveC {
void SetKeepDims(bool keep_dims);
void SetAxisType(int axis_type);
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_ARG_MAX_H_

View File

@ -14,12 +14,38 @@
* limitations under the License.
*/
#include "src/ops/ops.h"
#include "include/errorcode.h"
#include "utils/log_adapter.h"
#include "src/ir/tensor.h"
#include "src/ops/argmin.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int ArgMin::GetAxis() const { return this->primitive->value.AsArgMin()->axis; }
bool ArgMin::GetOutMaxValue() const { return this->primitive->value.AsArgMin()->outMaxValue; }
int ArgMin::GetTopK() const { return this->primitive->value.AsArgMin()->topK; }
bool ArgMin::GetKeepDims() const { return this->primitive->value.AsArgMin()->keepDims; }
int ArgMin::GetAxisType() const { return this->primitive->value.AsArgMin()->axisType; }
void ArgMin::SetAxis(int axis) { this->primitive->value.AsArgMin()->axis = axis; }
void ArgMin::SetOutMaxValue(bool out_max_value) { this->primitive->value.AsArgMin()->outMaxValue = out_max_value; }
void ArgMin::SetTopK(int top_k) { this->primitive->value.AsArgMin()->topK = top_k; }
void ArgMin::SetKeepDims(bool keep_dims) { this->primitive->value.AsArgMin()->keepDims = keep_dims; }
void ArgMin::SetAxisType(int axis_type) { this->primitive->value.AsArgMin()->axisType = axis_type; }
#else
int ArgMin::GetAxis() const { return this->primitive->value_as_ArgMin()->axis(); }
bool ArgMin::GetOutMaxValue() const { return this->primitive->value_as_ArgMin()->outMaxValue(); }
int ArgMin::GetTopK() const { return this->primitive->value_as_ArgMin()->topK(); }
bool ArgMin::GetKeepDims() const { return this->primitive->value_as_ArgMin()->keepDims(); }
int ArgMin::GetAxisType() const { return this->primitive->value_as_ArgMin()->axisType(); }
void ArgMin::SetAxis(int axis) {}
void ArgMin::SetOutMaxValue(bool out_max_value) {}
void ArgMin::SetTopK(int top_k) {}
void ArgMin::SetKeepDims(bool keep_dims) {}
void ArgMin::SetAxisType(int axis_type) {}
#endif
namespace mindspore::lite {
int ArgMin::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front();
@ -42,10 +68,10 @@ int ArgMin::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
} else {
output_shape[axis] = argmin_prim->topK();
}
output->SetFormat(input->GetFormat());
output->set_shape(output_shape);
output->set_data_type(input->data_type());
return RET_OK;
}
} // namespace mindspore::lite
} // namespace lite
} // namespace mindspore

View File

@ -18,7 +18,7 @@
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "mindspore/lite/c_ops/primitive_c.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
@ -29,13 +29,11 @@
#define LITE_MINDSPORE_LITE_C_OPS_ARG_MIN_H_
namespace mindspore {
namespace lite {
class ArgMin : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
explicit ArgMin(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
explicit ArgMin(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
explicit ArgMin(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetAxis() const;
bool GetOutMaxValue() const;
@ -48,6 +46,7 @@ class ArgMin : public PrimitiveC {
void SetKeepDims(bool keep_dims);
void SetAxisType(int axis_type);
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_ARG_MIN_H_

View File

@ -14,13 +14,14 @@
* limitations under the License.
*/
#include "src/ops/ops.h"
#include "src/ops/arithmetic.h"
#include "include/errorcode.h"
#include "utils/log_adapter.h"
#include "src/ir/tensor.h"
namespace mindspore::lite {
int Arithmetic::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
namespace mindspore {
namespace lite {
int Arithmetic::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
if (inputs_.size() != kDoubleNum) {
MS_LOG(ERROR) << "The number of input must be " << kDoubleNum;
@ -103,5 +104,5 @@ int Arithmetic::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
output->set_shape(output_shape);
return RET_OK;
}
} // namespace mindspore::lite
} // namespace lite
} // namespace mindspore

View File

@ -18,7 +18,7 @@
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "mindspore/lite/c_ops/primitive_c.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
@ -29,13 +29,11 @@
#define LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_H_
namespace mindspore {
namespace lite {
class Arithmetic : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
explicit Arithmetic(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
explicit Arithmetic(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
explicit Arithmetic(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
bool Broadcasting() { return this->broadcasting_; }
int NDims() { return this->ndim_; }
@ -50,6 +48,7 @@ class Arithmetic : public PrimitiveC {
std::vector<int> in_shape1_;
std::vector<int> out_shape_;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_H_

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -14,12 +14,13 @@
* limitations under the License.
*/
#include "src/ops/ops.h"
#include "src/ops/arithmetic_self.h"
#include "include/errorcode.h"
#include "utils/log_adapter.h"
#include "src/ir/tensor.h"
namespace mindspore::lite {
namespace mindspore {
namespace lite {
int ArithmeticSelf::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front();
@ -32,7 +33,7 @@ int ArithmeticSelf::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto
return RET_OK;
}
output->set_shape(input->shape());
return RET_OK;
}
} // namespace mindspore::lite
} // namespace lite
} // namespace mindspore

View File

@ -18,7 +18,7 @@
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "mindspore/lite/c_ops/primitive_c.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
@ -28,15 +28,14 @@
#define LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_SELF_H_
namespace mindspore {
namespace lite {
class ArithmeticSelf : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
explicit ArithmeticSelf(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
explicit ArithmeticSelf(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
explicit ArithmeticSelf(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_SELF_H_

View File

@ -14,9 +14,10 @@
* limitations under the License.
*/
#include "c_ops/batch_norm.h"
#include "src/ops/batch_norm.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
float BatchNorm::GetEpsilon() const { return this->primitive->value.AsBatchNorm()->epsilon; }
@ -28,4 +29,5 @@ float BatchNorm::GetEpsilon() const { return this->primitive->value_as_BatchNorm
void BatchNorm::SetEpsilon(float epsilon) {}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -18,7 +18,7 @@
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "mindspore/lite/c_ops/primitive_c.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
@ -29,16 +29,15 @@
#define LITE_MINDSPORE_LITE_C_OPS_BATCH_NORM_H_
namespace mindspore {
namespace lite {
class BatchNorm : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
explicit BatchNorm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
explicit BatchNorm(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
explicit BatchNorm(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
float GetEpsilon() const;
void SetEpsilon(float epsilon);
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_BATCH_NORM_H_

View File

@ -14,12 +14,37 @@
* limitations under the License.
*/
#include "src/ops/ops.h"
#include "src/ops/batch_to_space.h"
#include "src/common/common.h"
#include "include/errorcode.h"
#include "utils/log_adapter.h"
#include "src/ir/tensor.h"
namespace mindspore::lite {
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
std::vector<int> BatchToSpace::GetBlockShape() const { return this->primitive->value.AsBatchToSpace()->blockShape; }
std::vector<int> BatchToSpace::GetCrops() const { return this->primitive->value.AsBatchToSpace()->crops; }
void BatchToSpace::SetBlockShape(const std::vector<int> &block_shape) {
this->primitive->value.AsBatchToSpace()->blockShape = block_shape;
}
void BatchToSpace::SetCrops(const std::vector<int> &crops) { this->primitive->value.AsBatchToSpace()->crops = crops; }
#else
std::vector<int> BatchToSpace::GetBlockShape() const {
auto fb_vector = this->primitive->value_as_BatchToSpace()->blockShape();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
std::vector<int> BatchToSpace::GetCrops() const {
auto fb_vector = this->primitive->value_as_BatchToSpace()->crops();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
void BatchToSpace::SetBlockShape(const std::vector<int> &block_shape) {}
void BatchToSpace::SetCrops(const std::vector<int> &crops) {}
#endif
namespace {
constexpr int kBatchToSpaceOutputNum = 1;
constexpr int kBatchToSpaceInputNum = 1;
@ -27,7 +52,7 @@ constexpr int kBlockShapeSize = 2;
constexpr int kCropsSize = 4;
} // namespace
int BatchToSpace::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) {
int BatchToSpace::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
MS_ASSERT(this->primitive != nullptr);
if (outputs.size() != kBatchToSpaceOutputNum || inputs.size() != kBatchToSpaceInputNum) {
MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
@ -49,49 +74,50 @@ int BatchToSpace::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d;
return RET_PARAM_INVALID;
}
auto prim = this->primitive->value_as_BatchToSpace();
auto block_shape = prim->blockShape();
if (block_shape->size() != kBlockShapeSize) {
auto block_shape = GetBlockShape();
if (block_shape.size() != kBlockShapeSize) {
MS_LOG(ERROR) << "Block shape size should be " << kBlockShapeSize;
return RET_PARAM_INVALID;
}
auto crops = prim->crops();
if (crops->size() != kCropsSize) {
auto crops = GetCrops();
if (crops.size() != kCropsSize) {
MS_LOG(ERROR) << "Crops size should be " << kCropsSize;
return RET_PARAM_INVALID;
}
size_t mul_block_shape = 1;
for (size_t i = 0; i < kBlockShapeSize; ++i) {
if (block_shape->Get(i) <= 0) {
if (block_shape[i] <= 0) {
MS_LOG(ERROR) << "Input block_shape should > 0!";
return RET_PARAM_INVALID;
}
if (input_shape[kNHWC_n_index] % block_shape->Get(i)) {
MS_LOG(ERROR) << "Dimension n " << input_shape[kNHWC_n_index] << " can not divide block_shape[" << i << "] "
<< block_shape->Get(i);
return RET_PARAM_INVALID;
if (input_shape[NHWC_N] % block_shape[i]) {
MS_LOG(ERROR) << "Dimension n " << input_shape[NHWC_N] << " can not divide block_shape[" << i << "] "
<< block_shape[i];
return 1;
}
mul_block_shape *= block_shape->Get(i);
mul_block_shape *= block_shape[i];
}
if (input_shape[kNHWC_n_index] < mul_block_shape) {
MS_LOG(ERROR) << "Dimension n " << input_shape[kNHWC_n_index] << " < product of block shape!";
if (input_shape[NHWC_N] < mul_block_shape) {
MS_LOG(ERROR) << "Dimension n " << input_shape[NHWC_N] << " < product of block shape!";
return RET_PARAM_INVALID;
}
for (size_t i = 0; i < kCropsSize; ++i) {
if (crops->Get(i) < 0) {
if (crops[i] < 0) {
MS_LOG(ERROR) << "Input crops should >= 0";
return RET_PARAM_INVALID;
}
}
std::vector<int32_t> output_shape(input_shape.size());
output_shape[kNHWC_n_index] = input_shape[kNHWC_n_index] / mul_block_shape;
output_shape[kNHWC_h_index] = input_shape[kNHWC_h_index] * block_shape->Get(0) - crops->Get(0) - crops->Get(1);
output_shape[kNHWC_w_index] = input_shape[kNHWC_w_index] * block_shape->Get(1) - crops->Get(2) - crops->Get(3);
output_shape[kNHWC_c_index] = input_shape[kNHWC_c_index];
output_shape[NHWC_N] = input_shape[NHWC_N] / mul_block_shape;
output_shape[NHWC_H] = input_shape[NHWC_H] * block_shape[0] - crops[0] - crops[1];
output_shape[NHWC_W] = input_shape[NHWC_W] * block_shape[1] - crops[2] - crops[3];
output_shape[NHWC_C] = input_shape[NHWC_C];
outputs[0]->set_shape(output_shape);
return RET_OK;
}
} // namespace mindspore::lite
} // namespace lite
} // namespace mindspore

View File

@ -18,7 +18,7 @@
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "mindspore/lite/c_ops/primitive_c.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
@ -29,19 +29,18 @@
#define LITE_MINDSPORE_LITE_C_OPS_BATCH_TO_SPACE_H_
namespace mindspore {
namespace lite {
class BatchToSpace : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
explicit BatchToSpace(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
explicit BatchToSpace(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
explicit BatchToSpace(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
std::vector<int> GetBlockShape() const;
std::vector<int> GetCrops() const;
void SetBlockShape(const std::vector<int> &block_shape);
void SetCrops(const std::vector<int> &crops);
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_BATCH_TO_SPACE_H_

View File

@ -14,9 +14,10 @@
* limitations under the License.
*/
#include "c_ops/bias_add.h"
#include "src/ops/bias_add.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
std::vector<int> BiasAdd::GetAxis() const { return this->primitive->value.AsBiasAdd()->axis; }
@ -31,4 +32,5 @@ std::vector<int> BiasAdd::GetAxis() const {
void BiasAdd::SetAxis(const std::vector<int> &axis) {}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -18,7 +18,7 @@
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "mindspore/lite/c_ops/primitive_c.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
@ -29,16 +29,15 @@
#define LITE_MINDSPORE_LITE_C_OPS_BIAS_ADD_H_
namespace mindspore {
namespace lite {
class BiasAdd : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
explicit BiasAdd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
explicit BiasAdd(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
explicit BiasAdd(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
std::vector<int> GetAxis() const;
void SetAxis(const std::vector<int> &axis);
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_BIAS_ADD_H_

View File

@ -14,9 +14,10 @@
* limitations under the License.
*/
#include "c_ops/bias_grad.h"
#include "src/ops/bias_grad.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
std::vector<int> BiasGrad::GetAxis() const { return this->primitive->value.AsBiasGrad()->axis; }
@ -31,4 +32,5 @@ std::vector<int> BiasGrad::GetAxis() const {
void BiasGrad::SetAxis(const std::vector<int> &axis) {}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -18,7 +18,7 @@
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "mindspore/lite/c_ops/primitive_c.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
@ -29,16 +29,15 @@
#define LITE_MINDSPORE_LITE_C_OPS_BIAS_GRAD_H_
namespace mindspore {
namespace lite {
class BiasGrad : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
explicit BiasGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
explicit BiasGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
explicit BiasGrad(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
std::vector<int> GetAxis() const;
void SetAxis(const std::vector<int> &axis);
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_BIAS_GRAD_H_

View File

@ -14,9 +14,10 @@
* limitations under the License.
*/
#include "c_ops/bn_grad_input.h"
#include "src/ops/bn_grad_input.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
float BNGradInput::GetEps() const { return this->primitive->value.AsBNGradInput()->eps; }
int BNGradInput::GetChannels() const { return this->primitive->value.AsBNGradInput()->channels; }
@ -32,4 +33,5 @@ int BNGradInput::GetChannels() const { return this->primitive->value_as_BNGradIn
void BNGradInput::SetEps(float eps) {}
void BNGradInput::SetChannels(int channels) {}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -18,7 +18,7 @@
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "mindspore/lite/c_ops/primitive_c.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
@ -29,18 +29,17 @@
#define LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_
namespace mindspore {
namespace lite {
class BNGradInput : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
explicit BNGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
explicit BNGradInput(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
explicit BNGradInput(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
float GetEps() const;
int GetChannels() const;
void SetEps(float eps);
void SetChannels(int channels);
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_

View File

@ -14,22 +14,36 @@
* limitations under the License.
*/
#include "src/ops/ops.h"
#include "include/errorcode.h"
#include "utils/log_adapter.h"
#include "src/ir/tensor.h"
#include "src/ops/broadcast_to.h"
namespace mindspore::lite {
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
std::vector<int> BroadcastTo::GetDstShape() const { return this->primitive->value.AsBroadcastTo()->dst_shape; }
void BroadcastTo::SetDstShape(const std::vector<int> &dst_shape) {
this->primitive->value.AsBroadcastTo()->dst_shape = dst_shape;
}
#else
std::vector<int> BroadcastTo::GetDstShape() const {
auto fb_vector = this->primitive->value_as_BroadcastTo()->dst_shape();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
void BroadcastTo::SetDstShape(const std::vector<int> &dst_shape) {}
#endif
namespace {
constexpr int kBroadcastToInputNum = 1;
constexpr int kBroadcastToOutputNum = 1;
} // namespace
int BroadcastTo::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) {
int BroadcastTo::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
MS_ASSERT(this->primitive != nullptr);
if (inputs.size() != kBroadcastToInputNum || outputs.size() != kBroadcastToOutputNum) {
MS_LOG(ERROR) << "input size:" << inputs.size() << ", output size:" << outputs.size();
return RET_PARAM_INVALID;
return 1;
}
auto input = inputs.at(0);
std::vector<int32_t> dst_shape(this->primitive->value_as_BroadcastTo()->dst_shape()->begin(),
@ -40,19 +54,19 @@ int BroadcastTo::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<te
if (input_shape.size() > dst_shape.size()) {
MS_LOG(ERROR) << "input shape size " << input_shape.size() << " should <= broadcast to shape size "
<< dst_shape.size() << "!";
return RET_PARAM_INVALID;
return 1;
}
for (int i = dst_shape.size() - 1; i >= 0; --i) {
if (dst_shape[i] < 0) {
MS_LOG(ERROR) << "shape[" << i << "] = " << dst_shape[i] << " ] should be > 0!";
return RET_PARAM_INVALID;
return 1;
}
if (input_shape_index >= 0) {
auto dim = input_shape[input_shape_index];
if (dim != dst_shape[i] && dim != 1) {
MS_LOG(ERROR) << "Invalid broadcast shape!";
return RET_PARAM_INVALID;
return 1;
}
}
shape[i] = dst_shape[i];
@ -61,6 +75,7 @@ int BroadcastTo::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<te
outputs[0]->SetFormat(input->GetFormat());
outputs[0]->set_shape(shape);
outputs[0]->set_data_type(input->data_type());
return RET_OK;
return 0;
}
} // namespace mindspore::lite
} // namespace lite
} // namespace mindspore

View File

@ -18,7 +18,7 @@
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "mindspore/lite/c_ops/primitive_c.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
@ -29,17 +29,16 @@
#define LITE_MINDSPORE_LITE_C_OPS_BROADCAST_TO_H_
namespace mindspore {
namespace lite {
class BroadcastTo : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
explicit BroadcastTo(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
explicit BroadcastTo(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
explicit BroadcastTo(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
std::vector<int> GetDstShape() const;
void SetDstShape(const std::vector<int> &dst_shape);
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_BROADCAST_TO_H_

View File

@ -14,9 +14,10 @@
* limitations under the License.
*/
#include "c_ops/caffe_p_relu.h"
#include "src/ops/caffe_p_relu.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
bool CaffePReLU::GetChannelShared() const { return this->primitive->value.AsCaffePReLU()->channelShared; }
@ -30,4 +31,5 @@ bool CaffePReLU::GetChannelShared() const { return this->primitive->value_as_Caf
void CaffePReLU::SetChannelShared(bool channel_shared) {}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -18,8 +18,8 @@
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "mindspore/lite/c_ops/primitive_c.h"
#include "c_ops/activation.h"
#include "src/ops/primitive_c.h"
#include "src/ops/activation.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
@ -30,16 +30,15 @@
#define LITE_MINDSPORE_LITE_C_OPS_CAFFE_P_RE_L_U_H_
namespace mindspore {
namespace lite {
class CaffePReLU : public Activation {
public:
#ifdef PRIMITIVE_WRITEABLE
explicit CaffePReLU(schema::PrimitiveT *primitive) : Activation(primitive) {}
#else
explicit CaffePReLU(schema::Primitive *primitive) : Activation(primitive) {}
#endif
explicit CaffePReLU(OriginPrimitive *primitive) : Activation(primitive) {}
bool GetChannelShared() const;
void SetChannelShared(bool channel_shared);
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_CAFFE_P_RE_L_U_H_

View File

@ -14,12 +14,26 @@
* limitations under the License.
*/
#include "src/ops/ops.h"
#include "include/errorcode.h"
#include "utils/log_adapter.h"
#include "src/ir/tensor.h"
#include "src/ops/cast.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Cast::GetSrcT() const { return this->primitive->value.AsCast()->srcT; }
int Cast::GetDstT() const { return this->primitive->value.AsCast()->dstT; }
void Cast::SetSrcT(int src_t) { this->primitive->value.AsCast()->srcT = src_t; }
void Cast::SetDstT(int dst_t) { this->primitive->value.AsCast()->dstT = dst_t; }
#else
int Cast::GetSrcT() const { return this->primitive->value_as_Cast()->srcT(); }
int Cast::GetDstT() const { return this->primitive->value_as_Cast()->dstT(); }
void Cast::SetSrcT(int src_t) {}
void Cast::SetDstT(int dst_t) {}
#endif
namespace mindspore::lite {
int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front();
@ -49,4 +63,5 @@ int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
output->set_data_type(TypeId::kNumberTypeFloat32);
return RET_OK;
}
} // namespace mindspore::lite
} // namespace lite
} // namespace mindspore

View File

@ -18,7 +18,7 @@
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "mindspore/lite/c_ops/primitive_c.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
@ -29,19 +29,18 @@
#define LITE_MINDSPORE_LITE_C_OPS_CAST_H_
namespace mindspore {
namespace lite {
class Cast : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
explicit Cast(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
explicit Cast(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
explicit Cast(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetSrcT() const;
int GetDstT() const;
void SetSrcT(int src_t);
void SetDstT(int dst_t);
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_CAST_H_

View File

@ -0,0 +1,40 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/arithmetic_self.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif
#ifndef LITE_MINDSPORE_LITE_C_OPS_CEIL_H_
#define LITE_MINDSPORE_LITE_C_OPS_CEIL_H_
namespace mindspore {
namespace lite {
class Ceil : public ArithmeticSelf {
public:
explicit Ceil(OriginPrimitive *primitive) : ArithmeticSelf(primitive) {}
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_CEIL_H_

View File

@ -14,9 +14,10 @@
* limitations under the License.
*/
#include "c_ops/clip.h"
#include "src/ops/clip.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
float Clip::GetMax() const { return this->primitive->value.AsClip()->max; }
float Clip::GetMin() const { return this->primitive->value.AsClip()->min; }
@ -32,4 +33,5 @@ float Clip::GetMin() const { return this->primitive->value_as_Clip()->min(); }
void Clip::SetMax(float max) {}
void Clip::SetMin(float min) {}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -18,7 +18,7 @@
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "mindspore/lite/c_ops/primitive_c.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
@ -29,18 +29,17 @@
#define LITE_MINDSPORE_LITE_C_OPS_CLIP_H_
namespace mindspore {
namespace lite {
class Clip : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
explicit Clip(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
explicit Clip(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
explicit Clip(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
float GetMax() const;
float GetMin() const;
void SetMax(float max);
void SetMin(float min);
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_CLIP_H_

View File

@ -14,12 +14,28 @@
* limitations under the License.
*/
#include "src/ops/ops.h"
#include "src/ops/concat.h"
#include "include/errorcode.h"
#include "utils/log_adapter.h"
#include "src/ir/tensor.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Concat::GetAxis() const { return this->primitive->value.AsConcat()->axis; }
int Concat::GetN() const { return this->primitive->value.AsConcat()->n; }
void Concat::SetAxis(int axis) { this->primitive->value.AsConcat()->axis = axis; }
void Concat::SetN(int n) { this->primitive->value.AsConcat()->n = n; }
#else
int Concat::GetAxis() const { return this->primitive->value_as_Concat()->axis(); }
int Concat::GetN() const { return this->primitive->value_as_Concat()->n(); }
void Concat::SetAxis(int axis) {}
void Concat::SetN(int n) {}
#endif
namespace mindspore::lite {
namespace {
constexpr int kConcatOutputNum = 1;
}
@ -47,7 +63,6 @@ int Concat::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
MS_LOG(ERROR) << "Invalid axis: " << axis;
return RET_PARAM_INVALID;
}
auto input0_shape_without_axis = input0_shape;
input0_shape_without_axis.erase(input0_shape_without_axis.begin() + axis);
auto input0_data_type = inputs_.at(0)->data_type();
@ -58,7 +73,6 @@ int Concat::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
MS_LOG(ERROR) << "All inputs should have the same data type!";
return RET_PARAM_INVALID;
}
if (inputs_.at(i)->GetFormat() != input0_format) {
MS_LOG(ERROR) << "All input format should be the same!";
return RET_PARAM_INVALID;
@ -81,4 +95,5 @@ int Concat::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
outputs_[0]->set_shape(output_shape);
return RET_OK;
}
} // namespace mindspore::lite
} // namespace lite
} // namespace mindspore

Some files were not shown because too many files have changed in this diff Show More