forked from mindspore-Ecosystem/mindspore
!4628 change ops
Merge pull request !4628 from yeyunpeng2020/master_cops_3
This commit is contained in:
commit
b1cfb6d627
|
@ -1,3 +0,0 @@
|
|||
# Collect every c_ops implementation file (recursively) under this directory.
file(GLOB_RECURSE C_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cc)
# Build the sources as an object library so they can be folded into larger
# targets without producing a standalone archive.
add_library(c_ops_mid OBJECT ${C_OPS_SRC})
|
|
@ -1,59 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/addn.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
// Writeable build: attributes are read from / written to the mutable
// flatbuffer T-object held in primitive->value.
int AddN::GetN() const { return this->primitive->value.AsAddN()->N; }

void AddN::SetN(int n) { this->primitive->value.AsAddN()->N = n; }

#else

// Read-only build: attributes come from the serialized flatbuffer table;
// setters are intentionally no-ops.
int AddN::GetN() const { return this->primitive->value_as_AddN()->N(); }

void AddN::SetN(int n) {}
#endif
|
||||
namespace {
|
||||
constexpr int kLeastInputNum = 2;
|
||||
}
|
||||
int AddN::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
auto input = inputs.front();
|
||||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
if (inputs.size() < kLeastInputNum) {
|
||||
MS_LOG(ERROR) << "input size" << inputs.size() << " is error!";
|
||||
return 1;
|
||||
}
|
||||
for (int i = 1; i < inputs.size(); ++i) {
|
||||
if (inputs.at(i)->shape() != inputs.at(0)->shape()) {
|
||||
MS_LOG(ERROR) << "AddN inputs shape is not equal!";
|
||||
return 1;
|
||||
}
|
||||
if (inputs.at(i)->data_type() != inputs.at(0)->data_type()) {
|
||||
MS_LOG(ERROR) << "AddN all input data type should be the same!";
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
output->SetFormat(input->GetFormat());
|
||||
output->set_shape(input->shape());
|
||||
output->set_data_type(input->data_type());
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,75 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/argmax.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
// Writeable build: attributes are read from / written to the mutable
// flatbuffer T-object held in primitive->value.
int ArgMax::GetAxis() const { return this->primitive->value.AsArgMax()->axis; }
bool ArgMax::GetOutMaxValue() const { return this->primitive->value.AsArgMax()->outMaxValue; }
int ArgMax::GetTopK() const { return this->primitive->value.AsArgMax()->topK; }
bool ArgMax::GetKeepDims() const { return this->primitive->value.AsArgMax()->keepDims; }
int ArgMax::GetAxisType() const { return this->primitive->value.AsArgMax()->axisType; }

void ArgMax::SetAxis(int axis) { this->primitive->value.AsArgMax()->axis = axis; }
void ArgMax::SetOutMaxValue(bool out_max_value) { this->primitive->value.AsArgMax()->outMaxValue = out_max_value; }
void ArgMax::SetTopK(int top_k) { this->primitive->value.AsArgMax()->topK = top_k; }
void ArgMax::SetKeepDims(bool keep_dims) { this->primitive->value.AsArgMax()->keepDims = keep_dims; }
void ArgMax::SetAxisType(int axis_type) { this->primitive->value.AsArgMax()->axisType = axis_type; }

#else

// Read-only build: attributes come from the serialized flatbuffer table;
// setters are intentionally no-ops.
int ArgMax::GetAxis() const { return this->primitive->value_as_ArgMax()->axis(); }
bool ArgMax::GetOutMaxValue() const { return this->primitive->value_as_ArgMax()->outMaxValue(); }
int ArgMax::GetTopK() const { return this->primitive->value_as_ArgMax()->topK(); }
bool ArgMax::GetKeepDims() const { return this->primitive->value_as_ArgMax()->keepDims(); }
int ArgMax::GetAxisType() const { return this->primitive->value_as_ArgMax()->axisType(); }

void ArgMax::SetAxis(int axis) {}
void ArgMax::SetOutMaxValue(bool out_max_value) {}
void ArgMax::SetTopK(int top_k) {}
void ArgMax::SetKeepDims(bool keep_dims) {}
void ArgMax::SetAxisType(int axis_type) {}
#endif
|
||||
// Infers the ArgMax output shape: with topK == 1 and keepDims == false the
// reduced axis is dropped; otherwise that axis becomes topK wide. The output
// keeps the input's format and data type.
// Returns 0 on success, 1 on validation failure.
int ArgMax::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
  MS_ASSERT(this->primitive != nullptr);
  // Validate tensor counts BEFORE touching front() (UB on an empty vector).
  if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) {
    MS_LOG(ERROR) << "tensor number is error.";
    return 1;  // the original fell through here with a bad tensor count
  }
  auto input = inputs_.front();
  MS_ASSERT(input != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);

  std::vector<int> output_shape(input->shape());
  auto input_shape_size = input->shape().size();
  // Normalize a negative axis into [0, rank).
  int axis = GetAxis() < 0 ? GetAxis() + static_cast<int>(input_shape_size) : GetAxis();
  if (axis >= static_cast<int>(input_shape_size) || axis < 0) {
    MS_LOG(ERROR) << "Invalid axis " << GetAxis() << ", input shape size: " << input_shape_size;
    return 1;
  }
  if (GetTopK() == 1 && !GetKeepDims()) {
    // A single best index without keepDims removes the axis entirely.
    output_shape.erase(output_shape.begin() + axis);
  } else {
    output_shape[axis] = GetTopK();
  }

  output->SetFormat(input->GetFormat());
  output->set_shape(output_shape);
  output->set_data_type(input->data_type());
  return 0;
}
|
||||
} // namespace mindspore
|
|
@ -1,74 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/argmin.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
// Writeable build: attributes are read from / written to the mutable
// flatbuffer T-object held in primitive->value.
int ArgMin::GetAxis() const { return this->primitive->value.AsArgMin()->axis; }
bool ArgMin::GetOutMaxValue() const { return this->primitive->value.AsArgMin()->outMaxValue; }
int ArgMin::GetTopK() const { return this->primitive->value.AsArgMin()->topK; }
bool ArgMin::GetKeepDims() const { return this->primitive->value.AsArgMin()->keepDims; }
int ArgMin::GetAxisType() const { return this->primitive->value.AsArgMin()->axisType; }

void ArgMin::SetAxis(int axis) { this->primitive->value.AsArgMin()->axis = axis; }
void ArgMin::SetOutMaxValue(bool out_max_value) { this->primitive->value.AsArgMin()->outMaxValue = out_max_value; }
void ArgMin::SetTopK(int top_k) { this->primitive->value.AsArgMin()->topK = top_k; }
void ArgMin::SetKeepDims(bool keep_dims) { this->primitive->value.AsArgMin()->keepDims = keep_dims; }
void ArgMin::SetAxisType(int axis_type) { this->primitive->value.AsArgMin()->axisType = axis_type; }

#else

// Read-only build: attributes come from the serialized flatbuffer table;
// setters are intentionally no-ops.
int ArgMin::GetAxis() const { return this->primitive->value_as_ArgMin()->axis(); }
bool ArgMin::GetOutMaxValue() const { return this->primitive->value_as_ArgMin()->outMaxValue(); }
int ArgMin::GetTopK() const { return this->primitive->value_as_ArgMin()->topK(); }
bool ArgMin::GetKeepDims() const { return this->primitive->value_as_ArgMin()->keepDims(); }
int ArgMin::GetAxisType() const { return this->primitive->value_as_ArgMin()->axisType(); }

void ArgMin::SetAxis(int axis) {}
void ArgMin::SetOutMaxValue(bool out_max_value) {}
void ArgMin::SetTopK(int top_k) {}
void ArgMin::SetKeepDims(bool keep_dims) {}
void ArgMin::SetAxisType(int axis_type) {}
#endif
|
||||
// Infers the ArgMin output shape: with topK == 1 and keepDims == false the
// reduced axis is dropped; otherwise that axis becomes topK wide. The output
// keeps the input's format and data type.
// Returns 0 on success, 1 on validation failure.
int ArgMin::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
  MS_ASSERT(this->primitive != nullptr);
  // Validate tensor counts BEFORE touching front() (UB on an empty vector).
  if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) {
    MS_LOG(ERROR) << "tensor number is error.";
    return 1;  // the original fell through here with a bad tensor count
  }
  auto input = inputs_.front();
  MS_ASSERT(input != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);

  auto input_shape_size = input->shape().size();
  // Normalize a negative axis into [0, rank).
  int axis = GetAxis() < 0 ? GetAxis() + static_cast<int>(input_shape_size) : GetAxis();
  if (axis >= static_cast<int>(input_shape_size) || axis < 0) {
    MS_LOG(ERROR) << "Invalid axis " << GetAxis() << ", input shape size: " << input_shape_size;
    return 1;
  }
  std::vector<int> output_shape(input->shape());
  if (GetTopK() == 1 && !GetKeepDims()) {
    // A single best index without keepDims removes the axis entirely.
    output_shape.erase(output_shape.begin() + axis);
  } else {
    output_shape[axis] = GetTopK();
  }

  output->SetFormat(input->GetFormat());
  output->set_shape(output_shape);
  output->set_data_type(input->data_type());
  return 0;
}
|
||||
} // namespace mindspore
|
|
@ -1,99 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the License);
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/arithmetic.h"
|
||||
|
||||
namespace mindspore {
|
||||
// Infers the output shape of an elementwise binary op with numpy-style
// broadcasting. Side effects: caches the rank-aligned shapes in the members
// in_shape0_/in_shape1_/out_shape_, sets ndim_ to the larger rank, and sets
// broadcasting_ when any pair of dims differs.
// Returns 0 on success, 1 on a bad tensor count, -1 when the two shapes
// cannot be broadcast together.
// NOTE(review): the three cached shapes are resized to a fixed 5, so this
// assumes ndim_ <= 5 — confirm inputs never exceed rank 5.
int Arithmetic::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
  MS_ASSERT(this->primitive != nullptr);
  if (inputs_.size() != kDoubleNum) {
    MS_LOG(ERROR) << "The number of input must be " << kDoubleNum;
    return 1;
  }
  if (outputs_.size() != kSingleNum) {
    MS_LOG(ERROR) << "The number of output must be " << kSingleNum;
    return 1;
  }
  auto input0 = inputs_[0];
  MS_ASSERT(input0 != nullptr);
  auto input1 = inputs_[1];
  MS_ASSERT(input1 != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);

  auto input_shape0 = input0->shape();
  auto input_shape1 = input1->shape();
  auto format = input0->GetFormat();
  in_shape0_.resize(5);
  in_shape1_.resize(5);
  out_shape_.resize(5);

  ndim_ = input_shape0.size();
  // Rank alignment: left-pad the lower-rank shape with 1s so both cached
  // shapes have ndim_ dimensions (the higher rank of the two inputs).
  if (input_shape0.size() < input_shape1.size()) {
    ndim_ = input_shape1.size();
    auto fill_dim_num = input_shape1.size() - input_shape0.size();
    int j = 0;
    for (int i = 0; i < input_shape1.size(); i++) {
      if (i < fill_dim_num) {
        in_shape0_[i] = 1;  // padded leading dim
      } else {
        in_shape0_[i] = input_shape0[j++];
      }
      in_shape1_[i] = input_shape1[i];
    }
    format = input0->GetFormat();
  } else if (input_shape0.size() > input_shape1.size()) {
    ndim_ = input_shape0.size();
    auto fill_dim_num = input_shape0.size() - input_shape1.size();
    int j = 0;
    for (int i = 0; i < input_shape0.size(); i++) {
      if (i < fill_dim_num) {
        in_shape1_[i] = 1;  // padded leading dim
      } else {
        in_shape1_[i] = input_shape1[j++];
      }
      in_shape0_[i] = input_shape0[i];
    }
  } else {
    // Equal ranks: copy both shapes straight through.
    for (int i = 0; i < input_shape0.size(); i++) {
      in_shape1_[i] = input_shape1[i];
      in_shape0_[i] = input_shape0[i];
    }
  }

  // Broadcast rule per dimension: equal dims pass through; a 1 broadcasts to
  // the other dim; anything else is an error.
  std::vector<int> output_shape;
  for (size_t i = 0; i < ndim_; i++) {
    if (in_shape0_[i] != in_shape1_[i]) {
      if (in_shape0_[i] == 1) {
        out_shape_[i] = in_shape1_[i];
      } else if (in_shape1_[i] == 1) {
        out_shape_[i] = in_shape0_[i];
      } else {
        MS_LOG(ERROR) << "shapes of input tensors can not be broadCasted";
        return -1;
      }
      broadcasting_ = true;
    } else {
      out_shape_[i] = in_shape0_[i];
    }
    output_shape.push_back(out_shape_[i]);
  }
  output->SetFormat(format);
  output->set_shape(output_shape);
  output->set_data_type(input0->data_type());
  return 0;
}
|
||||
} // namespace mindspore
|
|
@ -1,34 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the License);
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/arithmetic_self.h"
|
||||
|
||||
namespace mindspore {
|
||||
// Unary elementwise ops are shape-preserving: the single output simply
// mirrors the single input's format, shape and data type.
// Returns 0 unconditionally.
int ArithmeticSelf::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
                               std::vector<lite::tensor::Tensor *> outputs_) {
  MS_ASSERT(this->primitive != nullptr);
  auto *in_tensor = inputs_.front();
  MS_ASSERT(in_tensor != nullptr);
  auto *out_tensor = outputs_.front();
  MS_ASSERT(out_tensor != nullptr);

  // Copy the input's metadata straight onto the output.
  out_tensor->SetFormat(in_tensor->GetFormat());
  out_tensor->set_shape(in_tensor->shape());
  out_tensor->set_data_type(in_tensor->data_type());

  return 0;
}
|
||||
} // namespace mindspore
|
|
@ -1,114 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/batch_to_space.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
// Writeable build: attributes are read from / written to the mutable
// flatbuffer T-object held in primitive->value.
std::vector<int> BatchToSpace::GetBlockShape() const { return this->primitive->value.AsBatchToSpace()->blockShape; }
std::vector<int> BatchToSpace::GetCrops() const { return this->primitive->value.AsBatchToSpace()->crops; }

void BatchToSpace::SetBlockShape(const std::vector<int> &block_shape) {
  this->primitive->value.AsBatchToSpace()->blockShape = block_shape;
}
void BatchToSpace::SetCrops(const std::vector<int> &crops) { this->primitive->value.AsBatchToSpace()->crops = crops; }

#else

// Read-only build: vector attributes are copied out of the serialized
// flatbuffer vectors; setters are intentionally no-ops.
std::vector<int> BatchToSpace::GetBlockShape() const {
  auto fb_vector = this->primitive->value_as_BatchToSpace()->blockShape();
  return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
std::vector<int> BatchToSpace::GetCrops() const {
  auto fb_vector = this->primitive->value_as_BatchToSpace()->crops();
  return std::vector<int>(fb_vector->begin(), fb_vector->end());
}

void BatchToSpace::SetBlockShape(const std::vector<int> &block_shape) {}
void BatchToSpace::SetCrops(const std::vector<int> &crops) {}
#endif
|
||||
namespace {
constexpr int kBatchToSpaceOutputNum = 1;
constexpr int kBatchToSpaceInputNum = 1;
// blockShape holds one factor per spatial dim (H, W).
constexpr int kBlockShapeSize = 2;
// crops holds [h_begin, h_end, w_begin, w_end].
constexpr int kCropsSize = 4;
}  // namespace

// Infers the BatchToSpace output shape (NHWC only): batch is divided by the
// product of the block factors, H/W are multiplied by their block factor and
// then reduced by the corresponding crops, and C is unchanged.
// Returns 0 on success, 1 on any validation failure.
int BatchToSpace::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
  MS_ASSERT(this->primitive != nullptr);
  if (outputs.size() != kBatchToSpaceOutputNum || inputs.size() != kBatchToSpaceInputNum) {
    MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
    return 1;
  }

  auto input = inputs.at(0);
  if (input->GetFormat() != schema::Format_NHWC) {
    MS_LOG(ERROR) << "batch_to_space only support NHWC now!";
    return 1;
  }
  auto input_shape = input->shape();
  if (input_shape.size() != kDimension_4d) {
    MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d;
    return 1;
  }

  auto block_shape = GetBlockShape();
  if (block_shape.size() != kBlockShapeSize) {
    MS_LOG(ERROR) << "Block shape size should be " << kBlockShapeSize;
    return 1;
  }
  auto crops = GetCrops();
  if (crops.size() != kCropsSize) {
    MS_LOG(ERROR) << "Crops size should be " << kCropsSize;
    return 1;
  }
  size_t mul_block_shape = 1;

  // Each block factor must be positive and must evenly divide the batch dim;
  // accumulate their product to scale the batch dimension down.
  for (size_t i = 0; i < kBlockShapeSize; ++i) {
    if (block_shape[i] <= 0) {
      MS_LOG(ERROR) << "Input block_shape should > 0!";
      return 1;
    }
    if (input_shape[NHWC_N] % block_shape[i]) {
      MS_LOG(ERROR) << "Dimension n " << input_shape[NHWC_N] << " can not divide block_shape[" << i << "] "
                    << block_shape[i];
      return 1;
    }
    mul_block_shape *= block_shape[i];
  }

  if (input_shape[NHWC_N] < mul_block_shape) {
    MS_LOG(ERROR) << "Dimension n " << input_shape[NHWC_N] << " < product of block shape!";
    return 1;
  }
  for (size_t i = 0; i < kCropsSize; ++i) {
    if (crops[i] < 0) {
      MS_LOG(ERROR) << "Input crops should >= 0";
      return 1;
    }
  }
  std::vector<int32_t> output_shape(input_shape.size());
  output_shape[NHWC_N] = input_shape[NHWC_N] / mul_block_shape;
  // Spatial dims expand by their block factor, then both crop ends are cut.
  output_shape[NHWC_H] = input_shape[NHWC_H] * block_shape[0] - crops[0] - crops[1];
  output_shape[NHWC_W] = input_shape[NHWC_W] * block_shape[1] - crops[2] - crops[3];
  output_shape[NHWC_C] = input_shape[NHWC_C];

  outputs[0]->SetFormat(input->GetFormat());
  outputs[0]->set_shape(output_shape);
  outputs[0]->set_data_type(input->data_type());
  return 0;
}
|
||||
} // namespace mindspore
|
|
@ -1,79 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/broadcast_to.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
// Writeable build: the target shape lives in the mutable flatbuffer
// T-object held in primitive->value.
std::vector<int> BroadcastTo::GetDstShape() const { return this->primitive->value.AsBroadcastTo()->dst_shape; }

void BroadcastTo::SetDstShape(const std::vector<int> &dst_shape) {
  this->primitive->value.AsBroadcastTo()->dst_shape = dst_shape;
}

#else

// Read-only build: the target shape is copied out of the serialized
// flatbuffer vector; the setter is intentionally a no-op.
std::vector<int> BroadcastTo::GetDstShape() const {
  auto fb_vector = this->primitive->value_as_BroadcastTo()->dst_shape();
  return std::vector<int>(fb_vector->begin(), fb_vector->end());
}

void BroadcastTo::SetDstShape(const std::vector<int> &dst_shape) {}
#endif
|
||||
namespace {
|
||||
constexpr int kBroadcastToInputNum = 1;
|
||||
constexpr int kBroadcastToOutputNum = 1;
|
||||
} // namespace
|
||||
|
||||
int BroadcastTo::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
if (inputs.size() != kBroadcastToInputNum || outputs.size() != kBroadcastToOutputNum) {
|
||||
MS_LOG(ERROR) << "input size:" << inputs.size() << ", output size:" << outputs.size();
|
||||
return 1;
|
||||
}
|
||||
auto input = inputs.at(0);
|
||||
std::vector<int32_t> dst_shape(this->primitive->value_as_BroadcastTo()->dst_shape()->begin(),
|
||||
this->primitive->value_as_BroadcastTo()->dst_shape()->end());
|
||||
auto input_shape = input->shape();
|
||||
std::vector<int> shape(dst_shape.size());
|
||||
int input_shape_index = input_shape.size() - 1;
|
||||
if (input_shape.size() > dst_shape.size()) {
|
||||
MS_LOG(ERROR) << "input shape size " << input_shape.size() << " should <= broadcast to shape size "
|
||||
<< dst_shape.size() << "!";
|
||||
return 1;
|
||||
}
|
||||
|
||||
for (int i = dst_shape.size() - 1; i >= 0; --i) {
|
||||
if (dst_shape[i] < 0) {
|
||||
MS_LOG(ERROR) << "shape[" << i << "] = " << dst_shape[i] << " ] should be > 0!";
|
||||
return 1;
|
||||
}
|
||||
if (input_shape_index >= 0) {
|
||||
auto dim = input_shape[input_shape_index];
|
||||
if (dim != dst_shape[i] && dim != 1) {
|
||||
MS_LOG(ERROR) << "Invalid broadcast shape!";
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
shape[i] = dst_shape[i];
|
||||
--input_shape_index;
|
||||
}
|
||||
outputs[0]->SetFormat(input->GetFormat());
|
||||
outputs[0]->set_shape(shape);
|
||||
outputs[0]->set_data_type(input->data_type());
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,64 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/cast.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
// Writeable build: source/destination type ids are read from / written to
// the mutable flatbuffer T-object held in primitive->value.
int Cast::GetSrcT() const { return this->primitive->value.AsCast()->srcT; }
int Cast::GetDstT() const { return this->primitive->value.AsCast()->dstT; }

void Cast::SetSrcT(int src_t) { this->primitive->value.AsCast()->srcT = src_t; }
void Cast::SetDstT(int dst_t) { this->primitive->value.AsCast()->dstT = dst_t; }

#else

// Read-only build: type ids come from the serialized flatbuffer table;
// setters are intentionally no-ops.
int Cast::GetSrcT() const { return this->primitive->value_as_Cast()->srcT(); }
int Cast::GetDstT() const { return this->primitive->value_as_Cast()->dstT(); }

void Cast::SetSrcT(int src_t) {}
void Cast::SetDstT(int dst_t) {}
#endif
|
||||
// Infers the Cast output: shape and format follow the input, but the output
// data type is the cast's destination type. Validates that the input type
// matches srcT, is in the supported set, and that dstT is float32.
// Returns 0 on success, 1 on validation failure.
int Cast::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
  MS_ASSERT(this->primitive != nullptr);
  // Validate tensor counts BEFORE dereferencing front() (UB on empty vector).
  if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) {
    MS_LOG(ERROR) << "tensor number is error.";
    return 1;
  }
  auto input = inputs_.front();
  MS_ASSERT(input != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);

  if (input->data_type() != GetSrcT()) {
    MS_LOG(ERROR) << "input dataType is error";
    return 1;
  }
  if (kSupportDataType.find(input->data_type()) == kSupportDataType.end()) {
    MS_LOG(ERROR) << "Unsupported input data type " << input->data_type();
    return 1;
  }
  // Only float32 destinations are accepted for now.
  if (GetDstT() != kNumberTypeFloat && GetDstT() != kNumberTypeFloat32) {
    MS_LOG(ERROR) << "Invalid output datatype " << GetDstT();
    return 1;
  }
  output->SetFormat(input->GetFormat());
  output->set_shape(input->shape());
  // The output of a cast carries the DESTINATION type; the original wrongly
  // propagated the source type here.
  output->set_data_type(static_cast<TypeId>(GetDstT()));
  return 0;
}
|
||||
} // namespace mindspore
|
|
@ -1,93 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/concat.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
// Writeable build: attributes are read from / written to the mutable
// flatbuffer T-object held in primitive->value.
int Concat::GetAxis() const { return this->primitive->value.AsConcat()->axis; }
int Concat::GetN() const { return this->primitive->value.AsConcat()->n; }

void Concat::SetAxis(int axis) { this->primitive->value.AsConcat()->axis = axis; }
void Concat::SetN(int n) { this->primitive->value.AsConcat()->n = n; }

#else

// Read-only build: attributes come from the serialized flatbuffer table;
// setters are intentionally no-ops.
int Concat::GetAxis() const { return this->primitive->value_as_Concat()->axis(); }
int Concat::GetN() const { return this->primitive->value_as_Concat()->n(); }

void Concat::SetAxis(int axis) {}
void Concat::SetN(int n) {}
#endif
|
||||
namespace {
|
||||
constexpr int kConcatOutputNum = 1;
|
||||
}
|
||||
int Concat::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
|
||||
if (this->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive is nullptr!";
|
||||
return 1;
|
||||
}
|
||||
auto input0 = inputs_.front();
|
||||
auto output = outputs_.front();
|
||||
if (outputs_.size() != kConcatOutputNum) {
|
||||
MS_LOG(ERROR) << "output size is error";
|
||||
return 1;
|
||||
}
|
||||
MS_ASSERT(concat_prim != nullptr);
|
||||
auto input0_shape = inputs_.at(0)->shape();
|
||||
int axis = GetAxis() < 0 ? GetAxis() + input0_shape.size() : GetAxis();
|
||||
if (axis < 0 || axis >= input0_shape.size()) {
|
||||
MS_LOG(ERROR) << "Invalid axis: " << axis;
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto input0_shape_without_axis = input0_shape;
|
||||
input0_shape_without_axis.erase(input0_shape_without_axis.begin() + axis);
|
||||
auto input0_data_type = inputs_.at(0)->data_type();
|
||||
schema::Format input0_format = inputs_[0]->GetFormat();
|
||||
int output_axis_dim = input0_shape.at(axis);
|
||||
for (size_t i = 1; i < inputs_.size(); ++i) {
|
||||
if (inputs_.at(i)->data_type() != input0_data_type) {
|
||||
MS_LOG(ERROR) << "All inputs should have the same data type!";
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (inputs_.at(i)->GetFormat() != input0_format) {
|
||||
MS_LOG(ERROR) << "All input format should be the same!";
|
||||
return 1;
|
||||
}
|
||||
auto shape_tmp = inputs_.at(i)->shape();
|
||||
if (shape_tmp.size() != input0_shape.size()) {
|
||||
MS_LOG(ERROR) << "All inputs should have the same dim num!";
|
||||
return 1;
|
||||
}
|
||||
auto axis_tmp = shape_tmp[axis];
|
||||
shape_tmp.erase(shape_tmp.begin() + axis);
|
||||
if (input0_shape_without_axis != shape_tmp) {
|
||||
MS_LOG(ERROR) << "Inputs should have the same dim except axis!";
|
||||
return 1;
|
||||
}
|
||||
output_axis_dim += axis_tmp;
|
||||
}
|
||||
auto output_shape = input0_shape;
|
||||
output_shape[axis] = output_axis_dim;
|
||||
outputs_[0]->set_shape(output_shape);
|
||||
output->set_data_type(input0->data_type());
|
||||
output->SetFormat(input0->GetFormat());
|
||||
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,55 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/crop.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
// Writeable build: attributes are read from / written to the mutable
// flatbuffer T-object held in primitive->value.
long Crop::GetAxis() const { return this->primitive->value.AsCrop()->axis; }
std::vector<long> Crop::GetOffsets() const { return this->primitive->value.AsCrop()->offsets; }

void Crop::SetAxis(long axis) { this->primitive->value.AsCrop()->axis = axis; }
void Crop::SetOffsets(const std::vector<long> &offsets) { this->primitive->value.AsCrop()->offsets = offsets; }

#else

// Read-only build: attributes come from the serialized flatbuffer table;
// setters are intentionally no-ops.
long Crop::GetAxis() const { return this->primitive->value_as_Crop()->axis(); }
std::vector<long> Crop::GetOffsets() const {
  auto fb_vector = this->primitive->value_as_Crop()->offsets();
  return std::vector<long>(fb_vector->begin(), fb_vector->end());
}

void Crop::SetAxis(long axis) {}
void Crop::SetOffsets(const std::vector<long> &offsets) {}
#endif
|
||||
namespace {
|
||||
constexpr int kCropOutputNum = 1;
|
||||
constexpr int kCropInputNum = 2;
|
||||
} // namespace
|
||||
|
||||
int Crop::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
if (outputs.size() != kCropOutputNum || inputs.size() != kCropInputNum) {
|
||||
MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
|
||||
return 1;
|
||||
}
|
||||
outputs[0]->set_shape(inputs[1]->shape());
|
||||
outputs[0]->SetFormat(inputs[0]->GetFormat());
|
||||
outputs[0]->set_data_type(inputs[0]->data_type());
|
||||
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,75 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/depth_to_space.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int DepthToSpace::GetBlockSize() const { return this->primitive->value.AsDepthToSpace()->blockSize; }
|
||||
int DepthToSpace::GetFormat() const { return this->primitive->value.AsDepthToSpace()->format; }
|
||||
|
||||
void DepthToSpace::SetBlockSize(int block_size) { this->primitive->value.AsDepthToSpace()->blockSize = block_size; }
|
||||
void DepthToSpace::SetFormat(int format) { this->primitive->value.AsDepthToSpace()->format = format; }
|
||||
|
||||
#else
|
||||
|
||||
int DepthToSpace::GetBlockSize() const { return this->primitive->value_as_DepthToSpace()->blockSize(); }
|
||||
int DepthToSpace::GetFormat() const { return this->primitive->value_as_DepthToSpace()->format(); }
|
||||
|
||||
void DepthToSpace::SetBlockSize(int block_size) {}
|
||||
void DepthToSpace::SetFormat(int format) {}
|
||||
#endif
|
||||
namespace {
|
||||
constexpr int kDepthToSpaceOutputNum = 1;
|
||||
constexpr int kDepthToSpaceInputNum = 1;
|
||||
} // namespace
|
||||
|
||||
int DepthToSpace::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
if (outputs.size() != kDepthToSpaceOutputNum || inputs.size() != kDepthToSpaceInputNum) {
|
||||
MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto input = inputs.at(0);
|
||||
if (input->GetFormat() != schema::Format_NHWC) {
|
||||
MS_LOG(ERROR) << "depth_to_space only support NHWC now!";
|
||||
return 1;
|
||||
}
|
||||
auto input_shape = input->shape();
|
||||
if (input_shape.size() != kDimension_4d) {
|
||||
MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d;
|
||||
return 1;
|
||||
}
|
||||
|
||||
int32_t block_size = GetBlockSize();
|
||||
if (input_shape[NHWC_C] % (block_size * block_size) != 0 || input_shape[NHWC_C] == 0) {
|
||||
MS_LOG(ERROR) << "input dimension c size " << input_shape[NHWC_C] << " should be mulitple of block_size("
|
||||
<< block_size << ") * block_size)!";
|
||||
return 1;
|
||||
}
|
||||
std::vector<int32_t> output_shape(input_shape.size());
|
||||
output_shape[NHWC_N] = input_shape[NHWC_N];
|
||||
output_shape[NHWC_H] = input_shape[NHWC_H] * block_size;
|
||||
output_shape[NHWC_W] = input_shape[NHWC_W] * block_size;
|
||||
output_shape[NHWC_C] = input_shape[NHWC_C] / (block_size * block_size);
|
||||
outputs[0]->set_shape(output_shape);
|
||||
outputs[0]->set_data_type(input->data_type());
|
||||
outputs[0]->SetFormat(input->GetFormat());
|
||||
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,72 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/embedding_lookup.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
float EmbeddingLookup::GetMaxNorm() const { return this->primitive->value.AsEmbeddingLookup()->maxNorm; }
|
||||
|
||||
void EmbeddingLookup::SetMaxNorm(float max_norm) { this->primitive->value.AsEmbeddingLookup()->maxNorm = max_norm; }
|
||||
|
||||
#else
|
||||
|
||||
float EmbeddingLookup::GetMaxNorm() const { return this->primitive->value_as_EmbeddingLookup()->maxNorm(); }
|
||||
|
||||
void EmbeddingLookup::SetMaxNorm(float max_norm) {}
|
||||
#endif
|
||||
int EmbeddingLookup::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
|
||||
std::vector<lite::tensor::Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
if (inputs_.size() < kDoubleNum) {
|
||||
MS_LOG(ERROR) << "Embedding Lookup should have at least two inputs";
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (outputs_.size() != kSingleNum) {
|
||||
MS_LOG(ERROR) << "Embedding Lookup should have one outputs";
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto params_ = inputs_.front();
|
||||
MS_ASSERT(params_ != nullptr);
|
||||
auto ids = inputs_.back();
|
||||
MS_ASSERT(ids != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
|
||||
auto embedding_shape = params_->shape();
|
||||
embedding_shape.erase(embedding_shape.begin());
|
||||
|
||||
std::vector<int> output_shape(ids->shape());
|
||||
for (size_t i = 0; i < embedding_shape.size(); ++i) {
|
||||
output_shape.push_back(embedding_shape.at(i));
|
||||
}
|
||||
|
||||
for (int i = 1; i < inputs_.size() - 1; ++i) {
|
||||
auto embedding_shape_t = inputs_.at(i)->shape();
|
||||
embedding_shape_t.erase(embedding_shape_t.begin());
|
||||
if (embedding_shape_t != embedding_shape) {
|
||||
MS_LOG(ERROR) << "The embedded layers should have the same shape";
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
output->set_shape(output_shape);
|
||||
output->set_data_type(params_->data_type());
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,60 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/expand_dims.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int ExpandDims::GetDim() const { return this->primitive->value.AsExpandDims()->dim; }
|
||||
|
||||
void ExpandDims::SetDim(int dim) { this->primitive->value.AsExpandDims()->dim = dim; }
|
||||
|
||||
#else
|
||||
|
||||
int ExpandDims::GetDim() const { return this->primitive->value_as_ExpandDims()->dim(); }
|
||||
|
||||
void ExpandDims::SetDim(int dim) {}
|
||||
#endif
|
||||
int ExpandDims::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
auto input = inputs_.front();
|
||||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
if (inputs_.size() != kSingleNum) {
|
||||
MS_LOG(ERROR) << "input size is invalid";
|
||||
}
|
||||
if (outputs_.size() != kSingleNum) {
|
||||
MS_LOG(ERROR) << "output size is invalid";
|
||||
}
|
||||
|
||||
int dim = GetDim();
|
||||
if (dim < 0) {
|
||||
dim += input->shape().size() + 1;
|
||||
}
|
||||
if (dim > input->shape().size()) {
|
||||
MS_LOG(ERROR) << "attribute dim out of range";
|
||||
return 1;
|
||||
}
|
||||
auto out_shape = input->shape();
|
||||
out_shape.insert(out_shape.begin() + dim, 1, 1);
|
||||
output->set_shape(out_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,56 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/fill.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
std::vector<int> Fill::GetDims() const { return this->primitive->value.AsFill()->dims; }
|
||||
|
||||
void Fill::SetDims(const std::vector<int> &dims) { this->primitive->value.AsFill()->dims = dims; }
|
||||
|
||||
#else
|
||||
|
||||
std::vector<int> Fill::GetDims() const {
|
||||
auto fb_vector = this->primitive->value_as_Fill()->dims();
|
||||
return std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||
}
|
||||
|
||||
void Fill::SetDims(const std::vector<int> &dims) {}
|
||||
#endif
|
||||
int Fill::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
auto input = inputs_.front();
|
||||
auto output = outputs_.front();
|
||||
if (input == nullptr || output == nullptr) {
|
||||
MS_LOG(ERROR) << "Fill input or output is null!";
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) {
|
||||
MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size();
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::vector<int> output_shape;
|
||||
(void)output_shape.insert(output_shape.begin(), GetDims().begin(), GetDims().end());
|
||||
output->set_shape(output_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,47 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the License);
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/flatten.h"
|
||||
|
||||
namespace mindspore {
|
||||
int Flatten::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
auto input = inputs_.front();
|
||||
auto output = outputs_.front();
|
||||
if (input == nullptr || output == nullptr) {
|
||||
MS_LOG(ERROR) << "Flatten input or output is null!";
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) {
|
||||
MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size();
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto input_shape = input->shape();
|
||||
std::vector<int> output_shape(2);
|
||||
output_shape[0] = input_shape[0];
|
||||
output_shape[1] = 1;
|
||||
for (int i = 1; i < input_shape.size(); i++) {
|
||||
output_shape[1] *= input_shape[i];
|
||||
}
|
||||
output->set_shape(output_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,87 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/gather.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int Gather::GetAxis() const { return this->primitive->value.AsGather()->axis; }
|
||||
int Gather::GetBatchDims() const { return this->primitive->value.AsGather()->batchDims; }
|
||||
|
||||
void Gather::SetAxis(int axis) { this->primitive->value.AsGather()->axis = axis; }
|
||||
void Gather::SetBatchDims(int batch_dims) { this->primitive->value.AsGather()->batchDims = batch_dims; }
|
||||
|
||||
#else
|
||||
|
||||
int Gather::GetAxis() const { return this->primitive->value_as_Gather()->axis(); }
|
||||
int Gather::GetBatchDims() const { return this->primitive->value_as_Gather()->batchDims(); }
|
||||
|
||||
void Gather::SetAxis(int axis) {}
|
||||
void Gather::SetBatchDims(int batch_dims) {}
|
||||
#endif
|
||||
int Gather::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
if (inputs_.size() != kDoubleNum) {
|
||||
MS_LOG(ERROR) << "Gather should have two inputs";
|
||||
return 1;
|
||||
}
|
||||
if (outputs_.size() != kSingleNum) {
|
||||
MS_LOG(ERROR) << "Gather should have one outputs";
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto input = inputs_.at(0);
|
||||
MS_ASSERT(input != nullptr);
|
||||
auto indices = inputs_.at(1);
|
||||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(input != nullptr);
|
||||
|
||||
int axis = GetAxis();
|
||||
int batch_dims = GetBatchDims();
|
||||
if (axis < 0) {
|
||||
axis += input->shape().size();
|
||||
}
|
||||
auto indices_shape = indices->shape();
|
||||
int indices_rank = indices_shape.size();
|
||||
if (indices_rank < batch_dims + 1) {
|
||||
MS_LOG(ERROR) << "input[1]'s rank is less than batchDim + 1";
|
||||
return 1;
|
||||
}
|
||||
if (batch_dims != 0) {
|
||||
MS_LOG(ERROR) << "batchDims " << batch_dims << " != 0, which is not support";
|
||||
return 1;
|
||||
}
|
||||
auto in_shape = input->shape();
|
||||
int in_rank = in_shape.size();
|
||||
if (in_rank < axis + 1) {
|
||||
MS_LOG(ERROR) << "input[0]'s rank is less than axis + 1";
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::vector<int> out_shape{in_shape};
|
||||
out_shape.erase(out_shape.begin() + axis);
|
||||
for (size_t i = 0; i < indices_rank; i++) {
|
||||
out_shape.insert(out_shape.begin() + axis, indices_shape[i]);
|
||||
}
|
||||
|
||||
output->set_shape(out_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,74 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/gather_nd.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int GatherNd::GetBatchDims() const { return this->primitive->value.AsGatherNd()->batchDims; }
|
||||
|
||||
void GatherNd::SetBatchDims(int batch_dims) { this->primitive->value.AsGatherNd()->batchDims = batch_dims; }
|
||||
|
||||
#else
|
||||
|
||||
int GatherNd::GetBatchDims() const { return this->primitive->value_as_GatherNd()->batchDims(); }
|
||||
|
||||
void GatherNd::SetBatchDims(int batch_dims) {}
|
||||
#endif
|
||||
int GatherNd::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
if (inputs_.size() != kDoubleNum) {
|
||||
MS_LOG(ERROR) << "GatherNd should have two inputs";
|
||||
return 1;
|
||||
}
|
||||
if (outputs_.size() != kSingleNum) {
|
||||
MS_LOG(ERROR) << "GatherNd should have one outputs";
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto input = inputs_.at(0);
|
||||
MS_ASSERT(input != nullptr);
|
||||
auto indices = inputs_.at(1);
|
||||
MS_ASSERT(indices != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
|
||||
auto in_shape = input->shape();
|
||||
int in_rank = in_shape.size();
|
||||
auto indices_shape = indices->shape();
|
||||
int indices_rank = indices_shape.size();
|
||||
|
||||
if (indices_shape[indices_rank - 1] > in_rank) {
|
||||
MS_LOG(ERROR) << "Input of indices data is error!";
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::vector<int> out_shape;
|
||||
int i = 0;
|
||||
for (i = 0; i < indices_rank - 1; ++i) {
|
||||
out_shape.emplace_back(indices_shape[i]);
|
||||
}
|
||||
for (i = indices_shape[indices_rank - 1]; i < in_rank; ++i) {
|
||||
out_shape.emplace_back(in_shape[i]);
|
||||
}
|
||||
|
||||
output->set_shape(out_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,77 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/lstm.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
bool Lstm::GetBidirection() const { return this->primitive->value.AsLstm()->bidirection; }
|
||||
|
||||
void Lstm::SetBidirection(bool bidirection) { this->primitive->value.AsLstm()->bidirection = bidirection; }
|
||||
|
||||
#else
|
||||
|
||||
bool Lstm::GetBidirection() const { return this->primitive->value_as_Lstm()->bidirection(); }
|
||||
|
||||
void Lstm::SetBidirection(bool bidirection) {}
|
||||
#endif
|
||||
|
||||
const int kLstmInputNum = 6;
|
||||
const int kLstmOutputNum = 3;
|
||||
int Lstm::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
if (inputs_.size() != kLstmInputNum || outputs_.size() != kLstmOutputNum) {
|
||||
MS_LOG(ERROR) << "OpLstm inputs or outputs size error.";
|
||||
return 1;
|
||||
}
|
||||
auto input = inputs_.front();
|
||||
MS_ASSERT(input != nullptr);
|
||||
auto weight_i = inputs_.front();
|
||||
MS_ASSERT(input0 != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
|
||||
std::vector<int> in_shape = input->shape();
|
||||
std::vector<int> w_shape = weight_i->shape(); // layer, hidden_size * 4, input_size
|
||||
if (in_shape.size() != 3 || w_shape.size() != 3) {
|
||||
MS_LOG(ERROR) << "OpLstm input dims should be 3.";
|
||||
return 1;
|
||||
}
|
||||
|
||||
int hidden_size = w_shape[1] / 4;
|
||||
|
||||
// set output
|
||||
std::vector<int> out_shape(in_shape);
|
||||
out_shape[2] = hidden_size;
|
||||
if (GetBidirection()) {
|
||||
out_shape.insert(out_shape.begin() + 1, 2);
|
||||
}
|
||||
output->set_shape(out_shape);
|
||||
|
||||
// set hidden state, cell state
|
||||
std::vector<int> state_shape(in_shape);
|
||||
state_shape[0] = GetBidirection() ? 2 : 1;
|
||||
state_shape[2] = hidden_size;
|
||||
outputs_[1]->set_shape(state_shape);
|
||||
outputs_[2]->set_shape(state_shape);
|
||||
|
||||
for (int i = 0; i < kLstmOutputNum; i++) {
|
||||
outputs_[i]->set_data_type(input->data_type());
|
||||
outputs_[i]->SetFormat(input->GetFormat());
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,77 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/matmul.h"
|
||||
#include <utility>
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
bool MatMul::GetTransposeA() const { return this->primitive->value.AsMatMul()->transposeA; }
|
||||
bool MatMul::GetTransposeB() const { return this->primitive->value.AsMatMul()->transposeB; }
|
||||
|
||||
void MatMul::SetTransposeA(bool transpose_a) { this->primitive->value.AsMatMul()->transposeA = transpose_a; }
|
||||
void MatMul::SetTransposeB(bool transpose_b) { this->primitive->value.AsMatMul()->transposeB = transpose_b; }
|
||||
|
||||
#else
|
||||
|
||||
bool MatMul::GetTransposeA() const { return this->primitive->value_as_MatMul()->transposeA(); }
|
||||
bool MatMul::GetTransposeB() const { return this->primitive->value_as_MatMul()->transposeB(); }
|
||||
|
||||
void MatMul::SetTransposeA(bool transpose_a) {}
|
||||
void MatMul::SetTransposeB(bool transpose_b) {}
|
||||
#endif
|
||||
int MatMul::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
if (inputs_.size() != kDoubleNum) {
|
||||
MS_LOG(ERROR) << "OpMatMul inputs size: " << inputs_.size();
|
||||
return 1;
|
||||
}
|
||||
auto input0 = inputs_.front();
|
||||
MS_ASSERT(input0 != nullptr);
|
||||
auto input1 = inputs_.at(1);
|
||||
MS_ASSERT(input1 != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
|
||||
std::vector<int> a_shape = input0->shape();
|
||||
std::vector<int> b_shape = input1->shape();
|
||||
if (a_shape.size() < 2 || b_shape.size() < 2) {
|
||||
MS_LOG(ERROR) << "inputs shape is invalid";
|
||||
return 1;
|
||||
}
|
||||
|
||||
for (int i = 0; i < a_shape.size() - 2; ++i) {
|
||||
if (a_shape[i] != b_shape[i]) {
|
||||
MS_LOG(ERROR) << "Op MatMul's dimensions must be equal";
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
if (GetTransposeA()) {
|
||||
std::swap(a_shape[a_shape.size() - 1], a_shape[a_shape.size() - 2]);
|
||||
}
|
||||
if (GetTransposeB()) {
|
||||
std::swap(b_shape[b_shape.size() - 1], b_shape[b_shape.size() - 2]);
|
||||
}
|
||||
std::vector<int> c_shape(a_shape);
|
||||
c_shape[c_shape.size() - 1] = b_shape[b_shape.size() - 1];
|
||||
output->set_shape(c_shape);
|
||||
output->set_data_type(input0->data_type());
|
||||
output->SetFormat(input0->GetFormat());
|
||||
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,94 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/mean.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
std::vector<int> Mean::GetAxis() const { return this->primitive->value.AsMean()->axis; }
|
||||
bool Mean::GetKeepDims() const { return this->primitive->value.AsMean()->keepDims; }
|
||||
|
||||
void Mean::SetAxis(const std::vector<int> &axis) { this->primitive->value.AsMean()->axis = axis; }
|
||||
void Mean::SetKeepDims(bool keep_dims) { this->primitive->value.AsMean()->keepDims = keep_dims; }
|
||||
|
||||
#else
|
||||
|
||||
std::vector<int> Mean::GetAxis() const {
|
||||
auto fb_vector = this->primitive->value_as_Mean()->axis();
|
||||
return std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||
}
|
||||
bool Mean::GetKeepDims() const { return this->primitive->value_as_Mean()->keepDims(); }
|
||||
|
||||
void Mean::SetAxis(const std::vector<int> &axis) {}
|
||||
void Mean::SetKeepDims(bool keep_dims) {}
|
||||
#endif
|
||||
namespace {
|
||||
constexpr size_t kInputSize = 1;
|
||||
constexpr size_t kOutputSize = 1;
|
||||
} // namespace
|
||||
int Mean::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
|
||||
if (inputs_.size() != kInputSize || outputs_.size() != kOutputSize) {
|
||||
return 1;
|
||||
}
|
||||
auto input = inputs_.front();
|
||||
auto output = outputs_.front();
|
||||
if (input == nullptr || output == nullptr) {
|
||||
return 1;
|
||||
}
|
||||
if (this->primitive == nullptr) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
bool keep_dims = static_cast<bool>(GetKeepDims());
|
||||
std::vector<int> in_shape = input->shape();
|
||||
std::vector<int> out_shape;
|
||||
const auto &axes = GetAxis();
|
||||
auto num_axes = axes.size();
|
||||
// reduce on all axes
|
||||
if (num_axes == 0) {
|
||||
if (keep_dims) {
|
||||
for (auto i = 0; i < in_shape.size(); i++) {
|
||||
out_shape.push_back(1);
|
||||
}
|
||||
}
|
||||
output->set_shape(out_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
return 0;
|
||||
}
|
||||
|
||||
// reduce on selected axes
|
||||
for (size_t i = 0; i < in_shape.size(); i++) {
|
||||
bool reduce_axis = false;
|
||||
for (int idx = 0; idx < num_axes; ++idx) {
|
||||
if (static_cast<size_t>(axes[idx]) == i) {
|
||||
reduce_axis = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (reduce_axis) {
|
||||
if (keep_dims) {
|
||||
out_shape.push_back(1);
|
||||
}
|
||||
} else {
|
||||
out_shape.push_back(in_shape[i]);
|
||||
}
|
||||
}
|
||||
output->set_shape(out_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,41 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the License);
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/nchw2nhwc.h"
|
||||
|
||||
namespace mindspore {
|
||||
int Nchw2Nhwc::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
auto input = inputs_.front();
|
||||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
std::vector<int> nchw_shape = input->shape();
|
||||
if (nchw_shape.size() != 4) {
|
||||
output->set_shape(nchw_shape);
|
||||
} else {
|
||||
std::vector<int> nhwc_shape{nchw_shape};
|
||||
nhwc_shape[NHWC_N] = nchw_shape[NCHW_N];
|
||||
nhwc_shape[NHWC_H] = nchw_shape[NCHW_H];
|
||||
nhwc_shape[NHWC_W] = nchw_shape[NCHW_W];
|
||||
nhwc_shape[NHWC_C] = nchw_shape[NCHW_C];
|
||||
output->set_shape(nhwc_shape);
|
||||
}
|
||||
output->SetFormat(schema::Format_NHWC);
|
||||
output->set_data_type(input->data_type());
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,41 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the License);
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/nhwc2nchw.h"
|
||||
|
||||
namespace mindspore {
|
||||
int Nhwc2Nchw::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
auto input = inputs_.front();
|
||||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
std::vector<int> nhwc_shape = input->shape();
|
||||
if (nhwc_shape.size() != 4) {
|
||||
output->set_shape(nhwc_shape);
|
||||
} else {
|
||||
std::vector<int> nchw_shape{nhwc_shape};
|
||||
nchw_shape[NCHW_N] = nhwc_shape[NHWC_N];
|
||||
nchw_shape[NCHW_C] = nhwc_shape[NHWC_C];
|
||||
nchw_shape[NCHW_H] = nhwc_shape[NHWC_H];
|
||||
nchw_shape[NCHW_W] = nhwc_shape[NHWC_W];
|
||||
output->set_shape(nchw_shape);
|
||||
}
|
||||
output->SetFormat(schema::Format_NCHW);
|
||||
output->set_data_type(input->data_type());
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,79 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/one_hot.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
// Mutable build: the axis attribute is read from / written to the in-memory
// flatbuffer table via primitive->value.AsOneHot().
int OneHot::GetAxis() const { return this->primitive->value.AsOneHot()->axis; }

void OneHot::SetAxis(int axis) { this->primitive->value.AsOneHot()->axis = axis; }

#else

// Read-only build: the attribute comes from the serialized flatbuffer.
int OneHot::GetAxis() const { return this->primitive->value_as_OneHot()->axis(); }

// Intentional no-op: a serialized primitive cannot be modified.
void OneHot::SetAxis(int axis) {}
#endif
namespace {
// OneHot takes exactly four inputs: indices, depth, on_value, off_value.
constexpr size_t kOneHotInputNum = 4;
}
|
||||
// Infers the OneHot output shape: the input (indices) shape with an extra
// dimension of size *depth inserted at `axis`. Output dtype/format follow the
// on_value tensor. Returns 0 on success, 1 on any validation failure.
int OneHot::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
  if (this->primitive == nullptr) {
    return 1;
  }

  int axis = GetAxis();

  // Expected inputs: indices, depth, on_value, off_value.
  if (inputs.size() != kOneHotInputNum) {
    MS_LOG(ERROR) << "OneHot got inputs num " << inputs.size() << ", should be " << kOneHotInputNum;
    return 1;
  }
  if (outputs.empty()) {
    return 1;
  }
  auto depth_tensor = inputs.at(1);
  if (depth_tensor == nullptr) {
    return 1;
  }
  const int *depth = static_cast<int *>(depth_tensor->Data());
  // depth must carry constant data before the shape can be inferred; the
  // original dereferenced this pointer without checking it.
  if (depth == nullptr) {
    return 1;
  }

  auto input = inputs.front();
  if (input == nullptr) {
    return 1;
  }
  const auto input_shape = input->shape();
  int input_rank = static_cast<int>(input_shape.size());
  // Normalize a negative axis; the output has one more dimension than input.
  if (axis < 0) {
    axis += input_rank + 1;
  }
  std::vector<int> output_shape(input_shape);
  output_shape.insert(output_shape.cbegin() + axis, *depth);

  auto output = outputs.front();
  if (output == nullptr) {
    return 1;
  }
  output->set_shape(output_shape);

  // Output element type and format follow on_value.
  auto on_value = inputs.at(2);
  if (on_value == nullptr) {
    return 1;
  }
  output->set_data_type(on_value->data_type());
  output->SetFormat(on_value->GetFormat());

  return 0;
}
|
||||
} // namespace mindspore
|
|
@ -1,76 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/pad.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
// Mutable build: attributes are read from / written to the in-memory
// flatbuffer table via primitive->value.AsPad().
std::vector<int> Pad::GetPaddings() const { return this->primitive->value.AsPad()->paddings; }
int Pad::GetPaddingMode() const { return this->primitive->value.AsPad()->paddingMode; }
float Pad::GetConstantValue() const { return this->primitive->value.AsPad()->constantValue; }

void Pad::SetPaddings(const std::vector<int> &paddings) { this->primitive->value.AsPad()->paddings = paddings; }
void Pad::SetPaddingMode(int padding_mode) { this->primitive->value.AsPad()->paddingMode = padding_mode; }
void Pad::SetConstantValue(float constant_value) { this->primitive->value.AsPad()->constantValue = constant_value; }

#else

// Read-only build: paddings is copied out of the serialized flatbuffer vector.
std::vector<int> Pad::GetPaddings() const {
  auto fb_vector = this->primitive->value_as_Pad()->paddings();
  return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
int Pad::GetPaddingMode() const { return this->primitive->value_as_Pad()->paddingMode(); }
float Pad::GetConstantValue() const { return this->primitive->value_as_Pad()->constantValue(); }

// Intentional no-ops: a serialized primitive cannot be modified.
void Pad::SetPaddings(const std::vector<int> &paddings) {}
void Pad::SetPaddingMode(int padding_mode) {}
void Pad::SetConstantValue(float constant_value) {}
#endif
namespace {
// paddings holds a {before, after} pair for each of the 4 dimensions.
const size_t kPaddingsSize = 8;
const size_t kInputRank = 4;
}  // namespace
|
||||
// Infers the Pad output shape: each input dimension grows by its
// {before, after} padding pair. Inputs with rank < 4 align to the trailing
// entries of the 8-element paddings vector. Returns 0 on success, 1 on error.
int Pad::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
  MS_ASSERT(this->primitive != nullptr);
  if (this->primitive == nullptr) {
    return 1;
  }

  auto paddings = GetPaddings();
  // The loop below indexes paddings[2 * idx + 1] for idx up to kInputRank - 1;
  // the original never validated the size, risking an out-of-bounds read.
  if (paddings.size() != kPaddingsSize) {
    MS_LOG(ERROR) << "Pad paddings size " << paddings.size() << " should be " << kPaddingsSize;
    return 1;
  }

  auto input = inputs.front();
  if (input == nullptr) {
    return 1;
  }
  auto input_shape = input->shape();
  std::vector<int> output_shape;
  MS_ASSERT(input->shape().size() <= kInputRank);
  for (size_t i = 0; i < input_shape.size(); i++) {
    // Right-align the input dims against the 4-dim paddings layout.
    auto paddings_index = i + kInputRank - input_shape.size();
    auto shape = input_shape[i] + paddings[2 * paddings_index] + paddings[2 * paddings_index + 1];
    output_shape.push_back(shape);
  }

  auto output = outputs.front();
  if (output == nullptr) {
    return 1;
  }
  output->SetFormat(input->GetFormat());
  output->set_shape(output_shape);
  output->set_data_type(input->data_type());
  return 0;
}
|
||||
} // namespace mindspore
|
|
@ -1,139 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/pooling.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
// Mutable build: pooling attributes are read from / written to the in-memory
// flatbuffer table via primitive->value.AsPooling(). Setters cast the plain
// int arguments back to the corresponding schema enums.
int Pooling::GetFormat() const { return this->primitive->value.AsPooling()->format; }
int Pooling::GetPoolingMode() const { return this->primitive->value.AsPooling()->poolingMode; }
bool Pooling::GetGlobal() const { return this->primitive->value.AsPooling()->global; }
int Pooling::GetWindowW() const { return this->primitive->value.AsPooling()->windowW; }
int Pooling::GetWindowH() const { return this->primitive->value.AsPooling()->windowH; }
int Pooling::GetStrideW() const { return this->primitive->value.AsPooling()->strideW; }
int Pooling::GetStrideH() const { return this->primitive->value.AsPooling()->strideH; }
int Pooling::GetPadMode() const { return this->primitive->value.AsPooling()->padMode; }
int Pooling::GetPadUp() const { return this->primitive->value.AsPooling()->padUp; }
int Pooling::GetPadDown() const { return this->primitive->value.AsPooling()->padDown; }
int Pooling::GetPadLeft() const { return this->primitive->value.AsPooling()->padLeft; }
int Pooling::GetPadRight() const { return this->primitive->value.AsPooling()->padRight; }
int Pooling::GetRoundMode() const { return this->primitive->value.AsPooling()->roundMode; }

void Pooling::SetFormat(int format) { this->primitive->value.AsPooling()->format = (schema::Format)format; }
void Pooling::SetPoolingMode(int pooling_mode) {
  this->primitive->value.AsPooling()->poolingMode = (schema::PoolMode)pooling_mode;
}
void Pooling::SetGlobal(bool global) { this->primitive->value.AsPooling()->global = global; }
void Pooling::SetWindowW(int window_w) { this->primitive->value.AsPooling()->windowW = window_w; }
void Pooling::SetWindowH(int window_h) { this->primitive->value.AsPooling()->windowH = window_h; }
void Pooling::SetStrideW(int stride_w) { this->primitive->value.AsPooling()->strideW = stride_w; }
void Pooling::SetStrideH(int stride_h) { this->primitive->value.AsPooling()->strideH = stride_h; }
void Pooling::SetPadMode(int pad_mode) { this->primitive->value.AsPooling()->padMode = (schema::PadMode)pad_mode; }
void Pooling::SetPadUp(int pad_up) { this->primitive->value.AsPooling()->padUp = pad_up; }
void Pooling::SetPadDown(int pad_down) { this->primitive->value.AsPooling()->padDown = pad_down; }
void Pooling::SetPadLeft(int pad_left) { this->primitive->value.AsPooling()->padLeft = pad_left; }
void Pooling::SetPadRight(int pad_right) { this->primitive->value.AsPooling()->padRight = pad_right; }
void Pooling::SetRoundMode(int round_mode) {
  this->primitive->value.AsPooling()->roundMode = (schema::RoundMode)round_mode;
}

#else

// Read-only build: attributes come from the serialized flatbuffer accessors;
// setters are intentionally no-ops.
int Pooling::GetFormat() const { return this->primitive->value_as_Pooling()->format(); }
int Pooling::GetPoolingMode() const { return this->primitive->value_as_Pooling()->poolingMode(); }
bool Pooling::GetGlobal() const { return this->primitive->value_as_Pooling()->global(); }
int Pooling::GetWindowW() const { return this->primitive->value_as_Pooling()->windowW(); }
int Pooling::GetWindowH() const { return this->primitive->value_as_Pooling()->windowH(); }
int Pooling::GetStrideW() const { return this->primitive->value_as_Pooling()->strideW(); }
int Pooling::GetStrideH() const { return this->primitive->value_as_Pooling()->strideH(); }
int Pooling::GetPadMode() const { return this->primitive->value_as_Pooling()->padMode(); }
int Pooling::GetPadUp() const { return this->primitive->value_as_Pooling()->padUp(); }
int Pooling::GetPadDown() const { return this->primitive->value_as_Pooling()->padDown(); }
int Pooling::GetPadLeft() const { return this->primitive->value_as_Pooling()->padLeft(); }
int Pooling::GetPadRight() const { return this->primitive->value_as_Pooling()->padRight(); }
int Pooling::GetRoundMode() const { return this->primitive->value_as_Pooling()->roundMode(); }

void Pooling::SetFormat(int format) {}
void Pooling::SetPoolingMode(int pooling_mode) {}
void Pooling::SetGlobal(bool global) {}
void Pooling::SetWindowW(int window_w) {}
void Pooling::SetWindowH(int window_h) {}
void Pooling::SetStrideW(int stride_w) {}
void Pooling::SetStrideH(int stride_h) {}
void Pooling::SetPadMode(int pad_mode) {}
void Pooling::SetPadUp(int pad_up) {}
void Pooling::SetPadDown(int pad_down) {}
void Pooling::SetPadLeft(int pad_left) {}
void Pooling::SetPadRight(int pad_right) {}
void Pooling::SetRoundMode(int round_mode) {}
#endif
|
||||
// Infers the pooling output shape for an NHWC input (shape = {N, H, W, C}).
// SAME pad mode derives output dims from stride alone and back-computes the
// symmetric pads; otherwise output dims follow the configured round mode.
// Side effect: updates the cached pad members pad_u_/pad_d_/pad_l_/pad_r_.
int Pooling::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
  MS_ASSERT(this->primitive != nullptr);
  auto input = inputs_.front();
  MS_ASSERT(input != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
  int input_h = input->shape().at(1);
  int input_w = input->shape().at(2);

  auto window_h = GetWindowH();
  // Fixed: the original read GetWindowH() into window_w as well, which broke
  // every non-square pooling window. (Also dropped a stale
  // MS_ASSERT(pooling_prim != nullptr) on a name not declared in this scope.)
  auto window_w = GetWindowW();
  if (GetGlobal()) {
    // Global pooling: the window covers the entire spatial plane.
    window_h = input_h;
    window_w = input_w;
  }

  int output_h = 0;
  int output_w = 0;
  pad_l_ = GetPadLeft();
  pad_u_ = GetPadUp();
  pad_d_ = GetPadDown();
  pad_r_ = GetPadRight();
  if ((schema::PadMode)GetPadMode() == schema::PadMode_SAME) {
    // SAME: output = ceil(input / stride); pads are split as evenly as
    // possible, with the extra row/column on the bottom/right.
    output_w = std::ceil(static_cast<float>(input_w) / static_cast<float>(GetStrideW()));
    output_h = std::ceil(static_cast<float>(input_h) / static_cast<float>(GetStrideH()));
    auto pad_h_all = ((output_h - 1) * GetStrideH() + (window_h - 1) + 1 - input_h);
    auto pad_w_all = ((output_w - 1) * GetStrideW() + (window_w - 1) + 1 - input_w);
    pad_u_ = pad_h_all / 2;
    pad_d_ = pad_h_all - pad_u_;
    pad_l_ = pad_w_all / 2;
    pad_r_ = pad_w_all - pad_l_;
  } else {
    auto round_mode = GetRoundMode();
    if (round_mode == schema::RoundMode_FLOOR) {
      output_h = std::floor(static_cast<float>(input_h + pad_u_ + pad_d_ - window_h) / GetStrideH()) + 1;
      output_w = std::floor(static_cast<float>(input_w + pad_l_ + pad_r_ - window_w) / GetStrideW()) + 1;
    } else if (round_mode == schema::RoundMode_CEIL) {
      output_h = std::ceil(static_cast<float>(input_h + pad_u_ + pad_d_ - window_h) / GetStrideH()) + 1;
      output_w = std::ceil(static_cast<float>(input_w + pad_l_ + pad_r_ - window_w) / GetStrideW()) + 1;
    } else {
      MS_LOG(ERROR) << "unsupported round mode.";
    }
  }

  // todo: fmk type
  auto input_shape = input->shape();
  input_shape.at(1) = output_h;
  input_shape.at(2) = output_w;
  output->set_shape(input_shape);
  output->set_data_type(input->data_type());

  // todo: temp fix
  output->SetFormat(schema::Format_NHWC);
  return 0;
}
|
||||
} // namespace mindspore
|
|
@ -1,62 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/power.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
// Mutable build: power/scale/shift are read from / written to the in-memory
// flatbuffer table via primitive->value.AsPower().
float Power::GetPower() const { return this->primitive->value.AsPower()->power; }
float Power::GetScale() const { return this->primitive->value.AsPower()->scale; }
float Power::GetShift() const { return this->primitive->value.AsPower()->shift; }

void Power::SetPower(float power) { this->primitive->value.AsPower()->power = power; }
void Power::SetScale(float scale) { this->primitive->value.AsPower()->scale = scale; }
void Power::SetShift(float shift) { this->primitive->value.AsPower()->shift = shift; }

#else

// Read-only build: attributes come from the serialized flatbuffer.
float Power::GetPower() const { return this->primitive->value_as_Power()->power(); }
float Power::GetScale() const { return this->primitive->value_as_Power()->scale(); }
float Power::GetShift() const { return this->primitive->value_as_Power()->shift(); }

// Intentional no-ops: a serialized primitive cannot be modified.
void Power::SetPower(float power) {}
void Power::SetScale(float scale) {}
void Power::SetShift(float shift) {}
#endif
|
||||
// Infers the Power output shape. The output mirrors the base input's shape,
// dtype and format. When an optional exponent tensor is supplied it must
// match the base tensor element-wise (same shape and dtype).
// Returns 0 on success, 1 on a mismatched exponent tensor.
int Power::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
  MS_ASSERT(this->primitive != nullptr);
  auto base_tensor = inputs[0];
  MS_ASSERT(base_tensor != nullptr);
  lite::tensor::Tensor *exp_tensor = nullptr;
  if (inputs.size() == 2) {
    exp_tensor = inputs[1];
    MS_ASSERT(exp_tensor != nullptr);
  }
  auto result_tensor = outputs[0];
  MS_ASSERT(result_tensor != nullptr);

  if (exp_tensor != nullptr &&
      (exp_tensor->shape() != base_tensor->shape() || exp_tensor->data_type() != base_tensor->data_type())) {
    MS_LOG(ERROR) << "Power inputs shape or type is not equal!";
    return 1;
  }

  result_tensor->SetFormat(base_tensor->GetFormat());
  result_tensor->set_shape(base_tensor->shape());
  result_tensor->set_data_type(base_tensor->data_type());
  return 0;
}
|
||||
} // namespace mindspore
|
|
@ -1,127 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/prior_box.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
std::vector<int> PriorBox::GetMinSizes() const { return this->primitive->value.AsPriorBox()->max_sizes; }
|
||||
std::vector<int> PriorBox::GetMaxSizes() const { return this->primitive->value.AsPriorBox()->max_sizes; }
|
||||
std::vector<float> PriorBox::GetAspectRatios() const { return this->primitive->value.AsPriorBox()->aspect_ratios; }
|
||||
std::vector<float> PriorBox::GetVariances() const { return this->primitive->value.AsPriorBox()->variances; }
|
||||
int PriorBox::GetImageSizeW() const { return this->primitive->value.AsPriorBox()->image_size_w; }
|
||||
int PriorBox::GetImageSizeH() const { return this->primitive->value.AsPriorBox()->image_size_h; }
|
||||
float PriorBox::GetStepW() const { return this->primitive->value.AsPriorBox()->step_w; }
|
||||
float PriorBox::GetStepH() const { return this->primitive->value.AsPriorBox()->step_h; }
|
||||
bool PriorBox::GetClip() const { return this->primitive->value.AsPriorBox()->clip; }
|
||||
bool PriorBox::GetFlip() const { return this->primitive->value.AsPriorBox()->flip; }
|
||||
float PriorBox::GetOffset() const { return this->primitive->value.AsPriorBox()->offset; }
|
||||
|
||||
void PriorBox::SetMinSizes(const std::vector<int> &min_sizes) {
|
||||
this->primitive->value.AsPriorBox()->min_sizes = min_sizes;
|
||||
}
|
||||
void PriorBox::SetMaxSizes(const std::vector<int> &max_sizes) {
|
||||
this->primitive->value.AsPriorBox()->max_sizes = max_sizes;
|
||||
}
|
||||
void PriorBox::SetAspectRatios(const std::vector<float> &aspect_ratios) {
|
||||
this->primitive->value.AsPriorBox()->aspect_ratios = aspect_ratios;
|
||||
}
|
||||
void PriorBox::SetVariances(const std::vector<float> &variances) {
|
||||
this->primitive->value.AsPriorBox()->variances = variances;
|
||||
}
|
||||
void PriorBox::SetImageSizeW(int image_size_w) { this->primitive->value.AsPriorBox()->image_size_w = image_size_w; }
|
||||
void PriorBox::SetImageSizeH(int image_size_h) { this->primitive->value.AsPriorBox()->image_size_h = image_size_h; }
|
||||
void PriorBox::SetStepW(float step_w) { this->primitive->value.AsPriorBox()->step_w = step_w; }
|
||||
void PriorBox::SetStepH(float step_h) { this->primitive->value.AsPriorBox()->step_h = step_h; }
|
||||
void PriorBox::SetClip(bool clip) { this->primitive->value.AsPriorBox()->clip = clip; }
|
||||
void PriorBox::SetFlip(bool flip) { this->primitive->value.AsPriorBox()->flip = flip; }
|
||||
void PriorBox::SetOffset(float offset) { this->primitive->value.AsPriorBox()->offset = offset; }
|
||||
|
||||
#else
|
||||
|
||||
std::vector<int> PriorBox::GetMinSizes() const {
|
||||
auto fb_vector = this->primitive->value_as_PriorBox()->min_sizes();
|
||||
return std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||
}
|
||||
std::vector<int> PriorBox::GetMaxSizes() const {
|
||||
auto fb_vector = this->primitive->value_as_PriorBox()->max_sizes();
|
||||
return std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||
}
|
||||
std::vector<float> PriorBox::GetAspectRatios() const {
|
||||
auto fb_vector = this->primitive->value_as_PriorBox()->aspect_ratios();
|
||||
return std::vector<float>(fb_vector->begin(), fb_vector->end());
|
||||
}
|
||||
std::vector<float> PriorBox::GetVariances() const {
|
||||
auto fb_vector = this->primitive->value_as_PriorBox()->variances();
|
||||
return std::vector<float>(fb_vector->begin(), fb_vector->end());
|
||||
}
|
||||
int PriorBox::GetImageSizeW() const { return this->primitive->value_as_PriorBox()->image_size_w(); }
|
||||
int PriorBox::GetImageSizeH() const { return this->primitive->value_as_PriorBox()->image_size_h(); }
|
||||
float PriorBox::GetStepW() const { return this->primitive->value_as_PriorBox()->step_w(); }
|
||||
float PriorBox::GetStepH() const { return this->primitive->value_as_PriorBox()->step_h(); }
|
||||
bool PriorBox::GetClip() const { return this->primitive->value_as_PriorBox()->clip(); }
|
||||
bool PriorBox::GetFlip() const { return this->primitive->value_as_PriorBox()->flip(); }
|
||||
float PriorBox::GetOffset() const { return this->primitive->value_as_PriorBox()->offset(); }
|
||||
|
||||
void PriorBox::SetMinSizes(const std::vector<int> &min_sizes) {}
|
||||
void PriorBox::SetMaxSizes(const std::vector<int> &max_sizes) {}
|
||||
void PriorBox::SetAspectRatios(const std::vector<float> &aspect_ratios) {}
|
||||
void PriorBox::SetVariances(const std::vector<float> &variances) {}
|
||||
void PriorBox::SetImageSizeW(int image_size_w) {}
|
||||
void PriorBox::SetImageSizeH(int image_size_h) {}
|
||||
void PriorBox::SetStepW(float step_w) {}
|
||||
void PriorBox::SetStepH(float step_h) {}
|
||||
void PriorBox::SetClip(bool clip) {}
|
||||
void PriorBox::SetFlip(bool flip) {}
|
||||
void PriorBox::SetOffset(float offset) {}
|
||||
#endif
|
||||
namespace {
|
||||
constexpr int kPriorBoxPoints = 4;
|
||||
constexpr int kPriorBoxN = 1;
|
||||
constexpr int kPriorBoxW = 1;
|
||||
constexpr int kPriorBoxC = 2;
|
||||
} // namespace
|
||||
|
||||
int PriorBox::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
|
||||
std::vector<float> different_aspect_ratios{1.0f};
|
||||
auto aspect_ratios = GetAspectRatios();
|
||||
MS_ASSERT(aspect_ratios != nullptr);
|
||||
for (auto i = 0; i < aspect_ratios.size(); i++) {
|
||||
float ratio = (aspect_ratios)[i];
|
||||
bool exist = std::any_of(different_aspect_ratios.begin(), different_aspect_ratios.end(),
|
||||
[&](float v) { return abs(ratio - v) < 1e-6; });
|
||||
if (!exist) {
|
||||
different_aspect_ratios.emplace_back(ratio);
|
||||
if (GetFlip()) {
|
||||
different_aspect_ratios.emplace_back(1.0f / ratio);
|
||||
}
|
||||
}
|
||||
}
|
||||
int32_t num_priors_box = GetMinSizes().size() * different_aspect_ratios.size() + GetMaxSizes().size();
|
||||
auto input = inputs_.at(0);
|
||||
MS_ASSERT(input != nullptr);
|
||||
int32_t h = input->Height() * input->Width() * num_priors_box * kPriorBoxPoints;
|
||||
|
||||
std::vector<int> output_shape{kPriorBoxN, h, kPriorBoxW, kPriorBoxC};
|
||||
auto output = outputs_.at(0);
|
||||
MS_ASSERT(output != nullptr);
|
||||
|
||||
output->set_shape(output_shape);
|
||||
output->set_data_type(kNumberTypeFloat32);
|
||||
output->SetFormat(input->GetFormat());
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,49 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/quant_dtype_cast.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
// Mutable build: source/destination type ids are read from / written to the
// in-memory flatbuffer table via primitive->value.AsQuantDTypeCast().
int QuantDTypeCast::GetSrcT() const { return this->primitive->value.AsQuantDTypeCast()->srcT; }
int QuantDTypeCast::GetDstT() const { return this->primitive->value.AsQuantDTypeCast()->dstT; }

void QuantDTypeCast::SetSrcT(int src_t) { this->primitive->value.AsQuantDTypeCast()->srcT = src_t; }
void QuantDTypeCast::SetDstT(int dst_t) { this->primitive->value.AsQuantDTypeCast()->dstT = dst_t; }

#else

// Read-only build: attributes come from the serialized flatbuffer.
int QuantDTypeCast::GetSrcT() const { return this->primitive->value_as_QuantDTypeCast()->srcT(); }
int QuantDTypeCast::GetDstT() const { return this->primitive->value_as_QuantDTypeCast()->dstT(); }

// Intentional no-ops: a serialized primitive cannot be modified.
void QuantDTypeCast::SetSrcT(int src_t) {}
void QuantDTypeCast::SetDstT(int dst_t) {}
#endif
|
||||
// Infers the QuantDTypeCast output: same shape and format as the input, with
// the element type replaced by the configured destination type (dstT).
int QuantDTypeCast::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
                               std::vector<lite::tensor::Tensor *> outputs_) {
  MS_ASSERT(this->primitive != nullptr);
  auto input = inputs_.front();
  MS_ASSERT(input != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
  output->set_shape(input->shape());
  // Use the GetSrcT()/GetDstT() accessors instead of reading
  // value_as_QuantDTypeCast() directly: the direct read only compiles when
  // PRIMITIVE_WRITEABLE is undefined, and the original assert compared against
  // the member function `param->srcT` without calling it (missing parens).
  MS_ASSERT(input->data_type() == GetSrcT());
  output->set_data_type(static_cast<TypeId>(GetDstT()));
  output->SetFormat(input->GetFormat());
  return 0;
}
|
||||
} // namespace mindspore
|
|
@ -1,59 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/range.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
// Mutable build: range parameters are read from / written to the in-memory
// flatbuffer table via primitive->value.AsRange().
int Range::GetDType() const { return this->primitive->value.AsRange()->dType; }
int Range::GetStart() const { return this->primitive->value.AsRange()->start; }
int Range::GetLimit() const { return this->primitive->value.AsRange()->limit; }
int Range::GetDelta() const { return this->primitive->value.AsRange()->delta; }

void Range::SetDType(int d_type) { this->primitive->value.AsRange()->dType = d_type; }
void Range::SetStart(int start) { this->primitive->value.AsRange()->start = start; }
void Range::SetLimit(int limit) { this->primitive->value.AsRange()->limit = limit; }
void Range::SetDelta(int delta) { this->primitive->value.AsRange()->delta = delta; }

#else

// Read-only build: attributes come from the serialized flatbuffer.
int Range::GetDType() const { return this->primitive->value_as_Range()->dType(); }
int Range::GetStart() const { return this->primitive->value_as_Range()->start(); }
int Range::GetLimit() const { return this->primitive->value_as_Range()->limit(); }
int Range::GetDelta() const { return this->primitive->value_as_Range()->delta(); }

// Intentional no-ops: a serialized primitive cannot be modified.
void Range::SetDType(int d_type) {}
void Range::SetStart(int start) {}
void Range::SetLimit(int limit) {}
void Range::SetDelta(int delta) {}
#endif
|
||||
// Infers the Range output shape: a 1-D tensor whose length is the number of
// steps of `delta` needed to go from `start` up to (but excluding) `limit`.
int Range::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
  MS_ASSERT(this->primitive != nullptr);
  auto input = inputs_.front();
  MS_ASSERT(input != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);

  int shape_size = std::ceil(static_cast<float>(GetLimit() - GetStart()) / GetDelta());
  // Fixed: the original did `std::vector<int> in_shape(1);` followed by
  // `in_shape.push_back(shape_size);`, producing the bogus two-element shape
  // {0, shape_size} instead of the intended 1-D shape {shape_size}.
  std::vector<int> out_shape(1, shape_size);
  output->set_shape(out_shape);
  output->set_data_type(input->data_type());
  output->SetFormat(input->GetFormat());

  return 0;
}
|
||||
} // namespace mindspore
|
|
@ -1,33 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/rank.h"
|
||||
|
||||
namespace mindspore {
|
||||
// Infers the Rank output: the rank of a tensor is a single scalar value, so
// the output is a one-element tensor carrying the input's dtype and format.
int Rank::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
  MS_ASSERT(this->primitive != nullptr);
  auto in_tensor = inputs_.front();
  MS_ASSERT(in_tensor != nullptr);
  auto out_tensor = outputs_.front();
  MS_ASSERT(out_tensor != nullptr);

  std::vector<int> scalar_shape{1};
  out_tensor->set_shape(scalar_shape);
  out_tensor->set_data_type(in_tensor->data_type());
  out_tensor->SetFormat(in_tensor->GetFormat());

  return 0;
}
|
||||
} // namespace mindspore
|
|
@ -1,99 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/reduce.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
// Mutable build: reduction attributes are read from / written to the
// in-memory flatbuffer table via primitive->value.AsReduce().
std::vector<int> Reduce::GetAxes() const { return this->primitive->value.AsReduce()->axes; }
int Reduce::GetKeepDims() const { return this->primitive->value.AsReduce()->keepDims; }
int Reduce::GetMode() const { return this->primitive->value.AsReduce()->mode; }

void Reduce::SetAxes(const std::vector<int> &axes) { this->primitive->value.AsReduce()->axes = axes; }
void Reduce::SetKeepDims(int keep_dims) { this->primitive->value.AsReduce()->keepDims = keep_dims; }
void Reduce::SetMode(int mode) { this->primitive->value.AsReduce()->mode = (schema::ReduceMode)mode; }

#else

// Read-only build: axes is copied out of the serialized flatbuffer vector;
// setters are intentionally no-ops.
std::vector<int> Reduce::GetAxes() const {
  auto fb_vector = this->primitive->value_as_Reduce()->axes();
  return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
int Reduce::GetKeepDims() const { return this->primitive->value_as_Reduce()->keepDims(); }
int Reduce::GetMode() const { return this->primitive->value_as_Reduce()->mode(); }

void Reduce::SetAxes(const std::vector<int> &axes) {}
void Reduce::SetKeepDims(int keep_dims) {}
void Reduce::SetMode(int mode) {}
#endif
namespace {
// Reduce::InferShape expects exactly one input and one output tensor.
constexpr size_t kInputSize = 1;
constexpr size_t kOutputSize = 1;
}  // namespace
|
||||
int Reduce::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
|
||||
if (inputs_.size() != kInputSize || outputs_.size() != kOutputSize) {
|
||||
return 1;
|
||||
}
|
||||
auto input = inputs_.front();
|
||||
auto output = outputs_.front();
|
||||
if (input == nullptr || output == nullptr) {
|
||||
return 1;
|
||||
}
|
||||
if (this->primitive == nullptr) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
bool keep_dims = static_cast<bool>(GetKeepDims());
|
||||
std::vector<int> in_shape = input->shape();
|
||||
std::vector<int> out_shape;
|
||||
const auto &axes = GetAxes();
|
||||
auto num_axes = axes.size();
|
||||
// reduce on all axes
|
||||
if (num_axes == 0) {
|
||||
if (keep_dims) {
|
||||
for (auto i = 0; i < in_shape.size(); i++) {
|
||||
out_shape.push_back(1);
|
||||
}
|
||||
}
|
||||
output->set_shape(out_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
return 0;
|
||||
}
|
||||
|
||||
// reduce on selected axes
|
||||
for (size_t i = 0; i < in_shape.size(); i++) {
|
||||
bool reduce_axis = false;
|
||||
for (int idx = 0; idx < num_axes; ++idx) {
|
||||
if (static_cast<size_t>((axes)[idx]) == i) {
|
||||
reduce_axis = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (reduce_axis) {
|
||||
if (keep_dims) {
|
||||
out_shape.push_back(1);
|
||||
}
|
||||
} else {
|
||||
out_shape.push_back(in_shape[i]);
|
||||
}
|
||||
}
|
||||
output->set_shape(out_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,153 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/reshape.h"
|
||||
#include <algorithm>
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
// Writeable-schema build: attributes are read from / written to the
// generated flatbuffer object API (value.AsReshape()).
int Reshape::GetFormat() const { return this->primitive->value.AsReshape()->format; }
std::vector<long> Reshape::GetShape() const { return this->primitive->value.AsReshape()->shape; }

void Reshape::SetFormat(int format) { this->primitive->value.AsReshape()->format = format; }
void Reshape::SetShape(const std::vector<long> &shape) { this->primitive->value.AsReshape()->shape = shape; }

#else

// Read-only build: attributes are decoded from the serialized flatbuffer;
// setters are intentionally no-ops.
int Reshape::GetFormat() const { return this->primitive->value_as_Reshape()->format(); }
std::vector<long> Reshape::GetShape() const {
  // Copy the flatbuffer vector into a std::vector for the caller.
  auto fb_vector = this->primitive->value_as_Reshape()->shape();
  return std::vector<long>(fb_vector->begin(), fb_vector->end());
}

void Reshape::SetFormat(int format) {}
void Reshape::SetShape(const std::vector<long> &shape) {}
#endif
|
||||
|
||||
// Resolves a requested reshape target into a concrete output shape.
// Rules: a dim of -1 is inferred from the remaining element count (at most
// one -1 allowed); a dim of 0 copies the input dimension at the same index;
// any other negative dim is rejected. Returns 0 on success, 1 on error.
// Fix: guard the final division — a fixed part that multiplies to zero
// previously caused a division by zero when a -1 dim was present.
int Reshape::CalNewShape(const lite::tensor::Tensor *in_tensor, std::vector<int> *out_shape) const {
  // Total element count of the input tensor.
  size_t in_shape_size = 1;
  for (size_t i = 0; i < in_tensor->shape().size(); i++) {
    in_shape_size *= in_tensor->shape()[i];
  }

  int64_t inferIndex = -1;      // position of the single -1 entry, if any
  size_t out_shapeSize = 1;     // product of all resolved (non -1) dims
  for (size_t i = 0; i < out_shape->size(); i++) {
    if (out_shape->at(i) == -1) {
      if (inferIndex == -1) {
        inferIndex = i;
      } else {
        MS_LOG(ERROR) << "output shape should has no more than one dim which need infer";
        return 1;
      }
    } else if (out_shape->at(i) < 0) {
      MS_LOG(ERROR) << "output shape dim should be non-negative";
      return 1;
    } else if (out_shape->at(i) == 0) {
      // 0 means "copy the corresponding input dim"; .at() throws if the
      // requested rank exceeds the input rank.
      out_shape->at(i) = in_tensor->shape().at(i);
      out_shapeSize *= out_shape->at(i);
    } else {
      out_shapeSize *= out_shape->at(i);
    }
  }

  if (inferIndex == -1 && out_shapeSize != in_shape_size) {
    MS_LOG(ERROR) << "output shapeSize: " << out_shapeSize << " should be equal to input shapeSize: " << in_shape_size;
    return 1;
  }
  if (inferIndex != -1) {
    if (out_shapeSize == 0) {
      // Cannot infer a dim when the fixed part has zero elements.
      MS_LOG(ERROR) << "output shapeSize should be non-zero when one dim needs infer";
      return 1;
    }
    out_shape->at(inferIndex) = in_shape_size / out_shapeSize;
  }
  return 0;
}
|
||||
|
||||
template <typename T>
|
||||
void CalShape(const T *data, const std::vector<lite::tensor::Tensor *> &inputs, std::vector<int> *out_shape,
|
||||
int shape_size) {
|
||||
int input_count = inputs[0]->ElementsNum();
|
||||
|
||||
int index = 0;
|
||||
int size = 1;
|
||||
for (size_t i = 0; i < shape_size; i++) {
|
||||
if (data[i] == -1) {
|
||||
index = i;
|
||||
} else {
|
||||
size *= data[i];
|
||||
}
|
||||
out_shape->push_back(data[i]);
|
||||
}
|
||||
if (data[index] == -1) {
|
||||
(*out_shape)[index] = input_count / size;
|
||||
}
|
||||
}
|
||||
|
||||
// Infers the Reshape output. The target shape comes either from a second
// (shape) tensor — whose data may not exist until runtime — or from the
// primitive's shape attribute; -1/0 entries are resolved by CalNewShape.
// Returns 0 on success, 1 on failure (including "defer to runtime").
int Reshape::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
  MS_ASSERT(this->primitive != nullptr);
  auto input = inputs_.front();
  MS_ASSERT(input != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);

  std::vector<int> out_shape;
  if (inputs_.size() == kDoubleNum) {
    // Shape is supplied as a tensor; dispatch on its element type.
    auto shape_tensor = inputs_.at(1);
    if (shape_tensor->Data() == nullptr) {
      MS_LOG(INFO) << "Do infer shape in runtime.";
      return 1;
    }
    size_t shape_size = shape_tensor->ElementsNum();
    switch (shape_tensor->data_type()) {
      case kNumberTypeInt8: {
        auto data = reinterpret_cast<int8_t *>(shape_tensor->Data());
        CalShape<int8_t>(data, inputs_, &out_shape, shape_size);
      } break;
      case kNumberTypeInt32: {
        auto data = reinterpret_cast<int32_t *>(shape_tensor->Data());
        CalShape<int32_t>(data, inputs_, &out_shape, shape_size);
      } break;
      case kNumberTypeFloat: {
        auto data = reinterpret_cast<float *>(shape_tensor->Data());
        CalShape<float>(data, inputs_, &out_shape, shape_size);
      } break;
      case kNumberTypeUInt32: {
        auto data = reinterpret_cast<uint32_t *>(shape_tensor->Data());
        CalShape<uint32_t>(data, inputs_, &out_shape, shape_size);
      } break;
      default: {
        MS_LOG(ERROR) << "Reshape weight tensor has unsupported dataType: " << shape_tensor->data_type();
        return 1;
      }
    }
  } else if (inputs_.size() == kSingleNum) {
    // Fix: take ONE copy of the attribute shape. The original called
    // GetShape() twice, pairing begin() of one temporary with end() of a
    // different temporary — undefined behavior.
    auto attr_shape = GetShape();
    std::copy(attr_shape.begin(), attr_shape.end(), std::back_inserter(out_shape));
  } else {
    MS_LOG(ERROR) << "inputs tensor size invalid.";
    return 1;
  }

  auto ret = CalNewShape(inputs_.front(), &out_shape);
  if (ret != 0) {
    MS_LOG(ERROR) << "CalNewShape error";
    return ret;
  }

  output->set_shape(out_shape);
  output->set_data_type(input->data_type());
  output->SetFormat(input->GetFormat());

  return 0;
}
|
||||
} // namespace mindspore
|
|
@ -1,82 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/resize.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
// Writeable-schema build: attributes are read from / written to the
// generated flatbuffer object API (value.AsResize()).
int Resize::GetFormat() const { return this->primitive->value.AsResize()->format; }
int Resize::GetMethod() const { return this->primitive->value.AsResize()->method; }
long Resize::GetNewHeight() const { return this->primitive->value.AsResize()->newHeight; }
long Resize::GetNewWidth() const { return this->primitive->value.AsResize()->newWidth; }
bool Resize::GetAlignCorners() const { return this->primitive->value.AsResize()->alignCorners; }
bool Resize::GetPreserveAspectRatio() const { return this->primitive->value.AsResize()->preserveAspectRatio; }

// The ints are narrowed into the schema enums here.
void Resize::SetFormat(int format) { this->primitive->value.AsResize()->format = (schema::Format)format; }
void Resize::SetMethod(int method) { this->primitive->value.AsResize()->method = (schema::ResizeMethod)method; }
void Resize::SetNewHeight(long new_height) { this->primitive->value.AsResize()->newHeight = new_height; }
void Resize::SetNewWidth(long new_width) { this->primitive->value.AsResize()->newWidth = new_width; }
void Resize::SetAlignCorners(bool align_corners) { this->primitive->value.AsResize()->alignCorners = align_corners; }
void Resize::SetPreserveAspectRatio(bool preserve_aspect_ratio) {
  this->primitive->value.AsResize()->preserveAspectRatio = preserve_aspect_ratio;
}

#else

// Read-only build: attributes are decoded from the serialized flatbuffer;
// setters are intentionally no-ops.
int Resize::GetFormat() const { return this->primitive->value_as_Resize()->format(); }
int Resize::GetMethod() const { return this->primitive->value_as_Resize()->method(); }
long Resize::GetNewHeight() const { return this->primitive->value_as_Resize()->newHeight(); }
long Resize::GetNewWidth() const { return this->primitive->value_as_Resize()->newWidth(); }
bool Resize::GetAlignCorners() const { return this->primitive->value_as_Resize()->alignCorners(); }
bool Resize::GetPreserveAspectRatio() const { return this->primitive->value_as_Resize()->preserveAspectRatio(); }

void Resize::SetFormat(int format) {}
void Resize::SetMethod(int method) {}
void Resize::SetNewHeight(long new_height) {}
void Resize::SetNewWidth(long new_width) {}
void Resize::SetAlignCorners(bool align_corners) {}
void Resize::SetPreserveAspectRatio(bool preserve_aspect_ratio) {}
#endif
|
||||
namespace {
|
||||
constexpr int kInputRank = 4;
|
||||
} // namespace
|
||||
int Resize::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
auto input = inputs_.front();
|
||||
if (input == nullptr) {
|
||||
return 1;
|
||||
}
|
||||
MS_ASSERT(input->shape().size() == kInputRank);
|
||||
|
||||
auto output = outputs_.front();
|
||||
if (output == nullptr) {
|
||||
return 1;
|
||||
}
|
||||
auto new_height = GetNewHeight();
|
||||
auto new_width = GetNewWidth();
|
||||
|
||||
std::vector<int> output_shape;
|
||||
output_shape.push_back(input->Batch());
|
||||
output_shape.push_back(new_height);
|
||||
output_shape.push_back(new_width);
|
||||
output_shape.push_back(input->Channel());
|
||||
output->set_shape(output_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,60 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/reverse_sequence.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
// Writeable-schema build: attributes are read from / written to the
// generated flatbuffer object API (value.AsReverseSequence()).
int ReverseSequence::GetSeqAxis() const { return this->primitive->value.AsReverseSequence()->seqAxis; }
int ReverseSequence::GetBatchAxis() const { return this->primitive->value.AsReverseSequence()->batchAxis; }
std::vector<int> ReverseSequence::GetSeqLengths() const {
  return this->primitive->value.AsReverseSequence()->seqLengths;
}

void ReverseSequence::SetSeqAxis(int seq_axis) { this->primitive->value.AsReverseSequence()->seqAxis = seq_axis; }
void ReverseSequence::SetBatchAxis(int batch_axis) {
  this->primitive->value.AsReverseSequence()->batchAxis = batch_axis;
}
void ReverseSequence::SetSeqLengths(const std::vector<int> &seq_lengths) {
  this->primitive->value.AsReverseSequence()->seqLengths = seq_lengths;
}

#else

// Read-only build: attributes are decoded from the serialized flatbuffer;
// setters are intentionally no-ops.
int ReverseSequence::GetSeqAxis() const { return this->primitive->value_as_ReverseSequence()->seqAxis(); }
int ReverseSequence::GetBatchAxis() const { return this->primitive->value_as_ReverseSequence()->batchAxis(); }
std::vector<int> ReverseSequence::GetSeqLengths() const {
  // Copy the flatbuffer vector into a std::vector for the caller.
  auto fb_vector = this->primitive->value_as_ReverseSequence()->seqLengths();
  return std::vector<int>(fb_vector->begin(), fb_vector->end());
}

void ReverseSequence::SetSeqAxis(int seq_axis) {}
void ReverseSequence::SetBatchAxis(int batch_axis) {}
void ReverseSequence::SetSeqLengths(const std::vector<int> &seq_lengths) {}
#endif
|
||||
// ReverseSequence only reorders elements, so the output tensor metadata
// (shape, dtype, format) is copied 1:1 from the input. Returns 0.
int ReverseSequence::InferShape(std::vector<lite::tensor::Tensor *> inputs,
                                std::vector<lite::tensor::Tensor *> outputs) {
  auto in_tensor = inputs.front();
  auto out_tensor = outputs.front();
  MS_ASSERT(in_tensor != nullptr);
  MS_ASSERT(out_tensor != nullptr);

  out_tensor->set_shape(in_tensor->shape());
  out_tensor->set_data_type(in_tensor->data_type());
  out_tensor->SetFormat(in_tensor->GetFormat());
  return 0;
}
|
||||
} // namespace mindspore
|
|
@ -1,75 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/roi_pooling.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
// Writeable-schema build: attributes are read from / written to the
// generated flatbuffer object API (value.AsROIPooling()).
int ROIPooling::GetPooledH() const { return this->primitive->value.AsROIPooling()->pooledH; }
int ROIPooling::GetPooledW() const { return this->primitive->value.AsROIPooling()->pooledW; }
float ROIPooling::GetScale() const { return this->primitive->value.AsROIPooling()->scale; }

void ROIPooling::SetPooledH(int pooled_h) { this->primitive->value.AsROIPooling()->pooledH = pooled_h; }
void ROIPooling::SetPooledW(int pooled_w) { this->primitive->value.AsROIPooling()->pooledW = pooled_w; }
void ROIPooling::SetScale(float scale) { this->primitive->value.AsROIPooling()->scale = scale; }

#else

// Read-only build: attributes are decoded from the serialized flatbuffer;
// setters are intentionally no-ops.
int ROIPooling::GetPooledH() const { return this->primitive->value_as_ROIPooling()->pooledH(); }
int ROIPooling::GetPooledW() const { return this->primitive->value_as_ROIPooling()->pooledW(); }
float ROIPooling::GetScale() const { return this->primitive->value_as_ROIPooling()->scale(); }

void ROIPooling::SetPooledH(int pooled_h) {}
void ROIPooling::SetPooledW(int pooled_w) {}
void ROIPooling::SetScale(float scale) {}
#endif
|
||||
|
||||
// ROIPooling output: one pooledH x pooledW x C slab per ROI row.
// inputs_[0] is the feature map, inputs_[1] the ROI tensor. Returns 0 on
// success, 1 on error.
int ROIPooling::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
  MS_ASSERT(this->primitive != nullptr);
  if (inputs_.size() != kDoubleNum) {
    MS_LOG(ERROR) << "inputs number is not equal to " << kDoubleNum;
    return 1;
  }
  auto input = inputs_.front();
  if (input == nullptr) {
    return 1;
  }
  auto roi = inputs_.at(1);
  if (roi == nullptr) {
    return 1;
  }
  auto output = outputs_.front();
  if (output == nullptr) {
    return 1;
  }

  const auto roi_shape = roi->shape();
  std::vector<int> out_dims;
  out_dims.reserve(4);
  out_dims.push_back(roi_shape[0]);  // one output slab per ROI
  out_dims.push_back(GetPooledH());
  out_dims.push_back(GetPooledW());
  out_dims.push_back(input->Channel());

  output->set_shape(out_dims);
  output->set_data_type(input->data_type());
  output->SetFormat(input->GetFormat());
  return 0;
}
|
||||
} // namespace mindspore
|
|
@ -1,61 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the License);
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/scatter_nd.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace {
|
||||
constexpr int kScatterNDInputNum = 3;
|
||||
constexpr int kScatterNDOutputNum = 1;
|
||||
constexpr int kScatterShapeIndex = 0;
|
||||
constexpr int kScatterIndicesIndex = 1;
|
||||
constexpr int kScatterUpdateIndex = 2;
|
||||
} // namespace
|
||||
|
||||
int ScatterND::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
|
||||
if (inputs_.size() != kScatterNDInputNum) {
|
||||
MS_LOG(ERROR) << "inputs number is not equal to " << kScatterNDInputNum;
|
||||
return 1;
|
||||
}
|
||||
if (outputs_.size() != kScatterNDOutputNum) {
|
||||
MS_LOG(ERROR) << "outputs number is not equal to " << kScatterNDInputNum;
|
||||
return 1;
|
||||
}
|
||||
auto shape = inputs_.at(kScatterShapeIndex);
|
||||
if (shape == nullptr) {
|
||||
MS_LOG(ERROR) << "shape null pointer dereferencing.";
|
||||
return 1;
|
||||
}
|
||||
auto indices = inputs_.at(kScatterIndicesIndex);
|
||||
if (indices == nullptr) {
|
||||
MS_LOG(ERROR) << "indices null pointer dereferencing.";
|
||||
return 1;
|
||||
}
|
||||
auto update = inputs_.at(kScatterUpdateIndex);
|
||||
if (update == nullptr) {
|
||||
MS_LOG(ERROR) << "update null pointer dereferencing.";
|
||||
return 1;
|
||||
}
|
||||
auto output = outputs_.front();
|
||||
auto shape_data = reinterpret_cast<int *>(shape->Data());
|
||||
std::vector<int> out_shape(shape_data, shape_data + shape->DataSize());
|
||||
output->set_shape(out_shape);
|
||||
output->set_data_type(update->data_type());
|
||||
output->SetFormat(update->GetFormat());
|
||||
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,52 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/shape.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace {
// Shape consumes one tensor and produces one 1-D tensor.
constexpr int kShapeInputNum = 1;
constexpr int kShapeOutputNum = 1;
}  // namespace
// Infers the Shape op's output: a 1-D tensor whose single dimension equals
// the input's rank. Returns 0 on success, 1 on error.
int Shape::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
  if (inputs_.size() != kShapeInputNum) {
    MS_LOG(ERROR) << "inputs to Shape operator should be 1, but " << inputs_.size() << " is given.";
    return 1;
  }
  if (outputs_.size() != kShapeOutputNum) {
    MS_LOG(ERROR) << "outputs to Shape operator should be 1, but " << outputs_.size() << " is given.";
    return 1;
  }

  auto in_tensor = inputs_.front();
  auto out_tensor = outputs_.front();
  // Output shape is [rank-of-input].
  std::vector<int> out_shape;
  out_shape.push_back(static_cast<int>(in_tensor->shape().size()));

  // NOTE(review): set_shape's return value is compared against 1 — this
  // presumably means "number of dims set"; confirm against the tensor API.
  auto ret_shape = out_tensor->set_shape(out_shape);
  if (ret_shape != 1 || size_t(out_tensor->shape()[0]) != in_tensor->shape().size()) {
    MS_LOG(ERROR) << "Set shape fails.";
    return 1;
  }
  // set_data_type appears to echo the stored type back; a mismatch is
  // treated as failure.
  auto ret_dtype = out_tensor->set_data_type(in_tensor->data_type());
  if (ret_dtype != in_tensor->data_type()) {
    MS_LOG(ERROR) << "Set datatype fails.";
    return 1;
  }

  return 0;
}
|
||||
} // namespace mindspore
|
|
@ -1,90 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/slice.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace {
// Slice consumes one tensor and produces one tensor.
constexpr int kSliceInputNum = 1;
constexpr int kSliceOutputNum = 1;
}  // namespace
#ifdef PRIMITIVE_WRITEABLE
// Writeable-schema build: attributes are read from / written to the
// generated flatbuffer object API (value.AsSlice()).
int SliceOp::GetFormat() const { return this->primitive->value.AsSlice()->format; }
std::vector<int> SliceOp::GetBegin() const { return this->primitive->value.AsSlice()->begin; }
std::vector<int> SliceOp::GetSize() const { return this->primitive->value.AsSlice()->size; }

void SliceOp::SetFormat(int format) { this->primitive->value.AsSlice()->format = format; }
void SliceOp::SetBegin(const std::vector<int> &begin) { this->primitive->value.AsSlice()->begin = begin; }
void SliceOp::SetSize(const std::vector<int> &size) { this->primitive->value.AsSlice()->size = size; }

#else

// Read-only build: attributes are decoded from the serialized flatbuffer;
// setters are intentionally no-ops.
int SliceOp::GetFormat() const { return this->primitive->value_as_Slice()->format(); }
std::vector<int> SliceOp::GetBegin() const {
  // Copy the flatbuffer vector into a std::vector for the caller.
  auto fb_vector = this->primitive->value_as_Slice()->begin();
  return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
std::vector<int> SliceOp::GetSize() const {
  auto fb_vector = this->primitive->value_as_Slice()->size();
  return std::vector<int>(fb_vector->begin(), fb_vector->end());
}

void SliceOp::SetFormat(int format) {}
void SliceOp::SetBegin(const std::vector<int> &begin) {}
void SliceOp::SetSize(const std::vector<int> &size) {}
#endif
|
||||
|
||||
// Infers the Slice output shape from the begin/size attributes; a size of
// -1 means "to the end of that axis". Returns 0 on success, 1 on error.
// Fixes over the original:
//   - GetBegin()/GetSize() are each called ONCE — the original paired
//     begin() and end() of two different temporaries (undefined behavior);
//   - begin/size rank is validated against the input rank before indexing;
//   - the loop index is size_t (was int vs size_t mismatch).
int SliceOp::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
  MS_ASSERT(this->primitive != nullptr);
  if (inputs.size() != kSliceInputNum || outputs.size() != kSliceOutputNum) {
    MS_LOG(ERROR) << "input size:" << inputs.size() << ",output size:" << outputs.size();
    return 1;
  }
  auto input = inputs.at(0);
  auto input_shape = input->shape();
  auto begin_attr = GetBegin();
  auto size_attr = GetSize();
  std::vector<int32_t> slice_begin(begin_attr.begin(), begin_attr.end());
  std::vector<int32_t> slice_size(size_attr.begin(), size_attr.end());
  if (slice_begin.size() != input_shape.size() || slice_size.size() != input_shape.size()) {
    MS_LOG(ERROR) << "begin/size rank does not match input rank " << input_shape.size();
    return 1;
  }
  std::vector<int32_t> output_shape(input_shape.size());
  for (size_t i = 0; i < input_shape.size(); ++i) {
    if (slice_size[i] < 0 && slice_size[i] != -1) {
      MS_LOG(ERROR) << "Invalid size input!size[" << i << "]=" << slice_size[i];
      return 1;
    }
    if (slice_begin[i] < 0) {
      MS_LOG(ERROR) << "Invalid begin input " << slice_begin[i] << " which should be >= 0";
      return 1;
    }
    if (input_shape[i] <= slice_begin[i]) {
      MS_LOG(ERROR) << "Invalid begin input!begin[" << i << "]=" << slice_begin[i]
                    << " which should be <= " << input_shape[i];
      return 1;
    }
    if (slice_size[i] > (input_shape[i] - slice_begin[i])) {
      MS_LOG(ERROR) << "Invalid size input " << slice_size[i]
                    << " which should be <= " << input_shape[i] - slice_begin[i];
      return 1;
    }

    // -1 extends the slice to the end of the axis.
    output_shape[i] = slice_size[i] < 0 ? input_shape[i] - slice_begin[i] : slice_size[i];
  }

  outputs[0]->set_shape(output_shape);
  outputs[0]->set_data_type(input->data_type());
  outputs[0]->SetFormat(input->GetFormat());

  return 0;
}
|
||||
} // namespace mindspore
|
|
@ -1,43 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/softmax.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
// Writeable-schema build: the axis attribute is read from / written to the
// generated flatbuffer object API (value.AsSoftMax()).
int SoftMax::GetAxis() const { return this->primitive->value.AsSoftMax()->axis; }

void SoftMax::SetAxis(int axis) { this->primitive->value.AsSoftMax()->axis = axis; }

#else

// Read-only build: the axis is decoded from the serialized flatbuffer; the
// setter is intentionally a no-op.
int SoftMax::GetAxis() const { return this->primitive->value_as_SoftMax()->axis(); }

void SoftMax::SetAxis(int axis) {}
#endif
|
||||
// SoftMax is element-wise along one axis, so the output tensor metadata
// (shape, dtype, format) mirrors the input exactly. Returns 0.
int SoftMax::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
  MS_ASSERT(this->primitive != nullptr);
  auto in_tensor = inputs_.front();
  auto out_tensor = outputs_.front();
  MS_ASSERT(in_tensor != nullptr);
  MS_ASSERT(out_tensor != nullptr);

  out_tensor->set_shape(in_tensor->shape());
  out_tensor->set_data_type(in_tensor->data_type());
  out_tensor->SetFormat(in_tensor->GetFormat());
  return 0;
}
|
||||
} // namespace mindspore
|
|
@ -1,110 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/space_to_batch.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
// Writeable-schema build: attributes are read from / written to the
// generated flatbuffer object API (value.AsSpaceToBatch()).
std::vector<int> SpaceToBatch::GetBlockShape() const { return this->primitive->value.AsSpaceToBatch()->blockShape; }
std::vector<int> SpaceToBatch::GetPaddings() const { return this->primitive->value.AsSpaceToBatch()->paddings; }

void SpaceToBatch::SetBlockShape(const std::vector<int> &block_shape) {
  this->primitive->value.AsSpaceToBatch()->blockShape = block_shape;
}
void SpaceToBatch::SetPaddings(const std::vector<int> &paddings) {
  this->primitive->value.AsSpaceToBatch()->paddings = paddings;
}

#else

// Read-only build: attributes are decoded from the serialized flatbuffer;
// setters are intentionally no-ops.
std::vector<int> SpaceToBatch::GetBlockShape() const {
  // Copy the flatbuffer vector into a std::vector for the caller.
  auto fb_vector = this->primitive->value_as_SpaceToBatch()->blockShape();
  return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
std::vector<int> SpaceToBatch::GetPaddings() const {
  auto fb_vector = this->primitive->value_as_SpaceToBatch()->paddings();
  return std::vector<int>(fb_vector->begin(), fb_vector->end());
}

void SpaceToBatch::SetBlockShape(const std::vector<int> &block_shape) {}
void SpaceToBatch::SetPaddings(const std::vector<int> &paddings) {}
#endif
|
||||
namespace {
|
||||
constexpr int kSpaceToBatchNDOutputNum = 1;
|
||||
constexpr int kSpaceToBatchNDInputNum = 1;
|
||||
constexpr int kBlockSizesSize = 2;
|
||||
constexpr int kPaddingsSize = 4;
|
||||
} // namespace
|
||||
|
||||
int SpaceToBatch::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
if (outputs.size() != kSpaceToBatchNDOutputNum || inputs.size() != kSpaceToBatchNDInputNum) {
|
||||
MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto input = inputs.at(0);
|
||||
if (input->GetFormat() != schema::Format_NHWC) {
|
||||
MS_LOG(ERROR) << "space_to_batch only support NHWC now!";
|
||||
return 1;
|
||||
}
|
||||
auto input_shape = input->shape();
|
||||
if (input_shape.size() != kDimension_4d) {
|
||||
MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d;
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (GetBlockShape().size() != kBlockSizesSize) {
|
||||
MS_LOG(ERROR) << "Block shape size should be " << kBlockSizesSize;
|
||||
return 1;
|
||||
}
|
||||
if (GetPaddings().size() != kPaddingsSize) {
|
||||
MS_LOG(ERROR) << "Crops size should be " << kPaddingsSize;
|
||||
return 1;
|
||||
}
|
||||
|
||||
for (int &iter : GetBlockShape()) {
|
||||
block_sizes_.emplace_back(iter);
|
||||
}
|
||||
|
||||
in_shape_.clear();
|
||||
padded_in_shape_.clear();
|
||||
paddings_.clear();
|
||||
in_shape_.emplace_back(input_shape.at(NHWC_N));
|
||||
padded_in_shape_.emplace_back(input_shape.at(NHWC_N));
|
||||
for (int i = 0; i < kBlockSizesSize; i++) {
|
||||
in_shape_.emplace_back(input_shape.at(i + 1));
|
||||
padded_in_shape_.emplace_back(input_shape.at(i + 1) + (paddings_.at(2 * i) + paddings_.at(2 * i + 1)));
|
||||
paddings_.emplace_back(paddings_.at(2 * i));
|
||||
paddings_.emplace_back(paddings_.at(2 * i + 1));
|
||||
if (paddings_.back() % block_sizes_.at(i)) {
|
||||
MS_LOG(ERROR) << "Padded shape does not divide block size " << block_sizes_.at(i);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
in_shape_.emplace_back(input_shape.at(NHWC_C));
|
||||
padded_in_shape_.emplace_back(input_shape.at(NHWC_C));
|
||||
|
||||
std::vector<int32_t> output_shape(input_shape.size());
|
||||
output_shape[NHWC_N] = input_shape[NHWC_N] * (block_sizes_[NHWC_N] * block_sizes_[NHWC_H]);
|
||||
output_shape[NHWC_H] = input_shape[NHWC_H] / block_sizes_[NHWC_N];
|
||||
output_shape[NHWC_W] = input_shape[NHWC_W] / block_sizes_[NHWC_H];
|
||||
output_shape[NHWC_C] = input_shape[NHWC_C];
|
||||
outputs[0]->set_shape(output_shape);
|
||||
outputs[0]->set_data_type(input->data_type());
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,73 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/space_to_depth.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
// Mutable-schema build: attributes live on the editable flatbuffer table.
int SpaceToDepth::GetBlockSize() const { return this->primitive->value.AsSpaceToDepth()->blockSize; }
int SpaceToDepth::GetFormat() const { return this->primitive->value.AsSpaceToDepth()->format; }

void SpaceToDepth::SetBlockSize(int block_size) { this->primitive->value.AsSpaceToDepth()->blockSize = block_size; }
void SpaceToDepth::SetFormat(int format) { this->primitive->value.AsSpaceToDepth()->format = format; }

#else

// Read-only schema build: getters read the serialized table; setters are
// intentionally no-ops.
int SpaceToDepth::GetBlockSize() const { return this->primitive->value_as_SpaceToDepth()->blockSize(); }
int SpaceToDepth::GetFormat() const { return this->primitive->value_as_SpaceToDepth()->format(); }

void SpaceToDepth::SetBlockSize(int block_size) {}
void SpaceToDepth::SetFormat(int format) {}
#endif
|
||||
namespace {
|
||||
constexpr int kSpaceToDepthOutputNum = 1;
|
||||
constexpr int kSpaceToDepthInputNum = 1;
|
||||
} // namespace
|
||||
|
||||
int SpaceToDepth::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
if (outputs.size() != kSpaceToDepthOutputNum || inputs.size() != kSpaceToDepthInputNum) {
|
||||
MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto input = inputs.at(0);
|
||||
if (input->GetFormat() != schema::Format_NHWC) {
|
||||
MS_LOG(ERROR) << "space_to_depth only support NHWC now!";
|
||||
return 1;
|
||||
}
|
||||
auto input_shape = input->shape();
|
||||
if (input_shape.size() != kDimension_4d) {
|
||||
MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d;
|
||||
return 1;
|
||||
}
|
||||
|
||||
int32_t block_size = GetBlockSize();
|
||||
if (input_shape[NHWC_C] % (block_size * block_size) != 0 || input_shape[NHWC_C] == 0) {
|
||||
MS_LOG(ERROR) << "input dimension c size " << input_shape[NHWC_C] << " should be mulitple of block_size("
|
||||
<< block_size << ") * block_size)!";
|
||||
return 1;
|
||||
}
|
||||
std::vector<int32_t> output_shape(input_shape.size());
|
||||
output_shape[NHWC_N] = input_shape[NHWC_N];
|
||||
output_shape[NHWC_H] = input_shape[NHWC_H] / block_size;
|
||||
output_shape[NHWC_W] = input_shape[NHWC_W] / block_size;
|
||||
output_shape[NHWC_C] = input_shape[NHWC_C] * (block_size * block_size);
|
||||
outputs[0]->set_shape(output_shape);
|
||||
outputs[0]->set_data_type(input->data_type());
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,83 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/split.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
// Mutable-schema build: attributes live on the editable flatbuffer table.
int Split::GetNumberSplit() const { return this->primitive->value.AsSplit()->numberSplit; }
std::vector<int> Split::GetSizeSplits() const { return this->primitive->value.AsSplit()->sizeSplits; }
int Split::GetSplitDim() const { return this->primitive->value.AsSplit()->splitDim; }

void Split::SetNumberSplit(int number_split) { this->primitive->value.AsSplit()->numberSplit = number_split; }
void Split::SetSizeSplits(const std::vector<int> &size_splits) {
  this->primitive->value.AsSplit()->sizeSplits = size_splits;
}
void Split::SetSplitDim(int split_dim) { this->primitive->value.AsSplit()->splitDim = split_dim; }

#else

// Read-only schema build: getters copy out of the serialized flatbuffer;
// setters are intentionally no-ops.
int Split::GetNumberSplit() const { return this->primitive->value_as_Split()->numberSplit(); }
std::vector<int> Split::GetSizeSplits() const {
  auto fb_vector = this->primitive->value_as_Split()->sizeSplits();
  return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
int Split::GetSplitDim() const { return this->primitive->value_as_Split()->splitDim(); }

void Split::SetNumberSplit(int number_split) {}
void Split::SetSizeSplits(const std::vector<int> &size_splits) {}
void Split::SetSplitDim(int split_dim) {}
#endif
|
||||
namespace {
|
||||
constexpr int kSplitInputNum = 1;
|
||||
} // namespace
|
||||
int Split::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
auto input = inputs_.front();
|
||||
MS_ASSERT(input != nullptr);
|
||||
|
||||
MS_ASSERT(spilt_prim != nullptr);
|
||||
if (inputs_.size() != kSplitInputNum) {
|
||||
MS_LOG(ERROR) << "inputs number is not equal to " << kSplitInputNum;
|
||||
return 1;
|
||||
}
|
||||
auto output = outputs_.front();
|
||||
if (output == nullptr) {
|
||||
MS_LOG(ERROR) << "output null pointer dereferencing.";
|
||||
return 1;
|
||||
}
|
||||
int number_split = GetNumberSplit();
|
||||
if (outputs_.size() != number_split) {
|
||||
MS_LOG(ERROR) << "outputs number is not equal to " << number_split;
|
||||
return 1;
|
||||
}
|
||||
int split_dim = GetSplitDim();
|
||||
std::vector<int> input_shape = input->shape();
|
||||
std::vector<int> size_split;
|
||||
size_split.insert(size_split.begin(), GetSizeSplits().begin(), GetSizeSplits().end());
|
||||
|
||||
for (int i = 0; i < number_split; ++i) {
|
||||
std::vector<int> output_shape;
|
||||
output_shape.insert(output_shape.begin(), input_shape.begin(), input_shape.end());
|
||||
auto split_dim_i = size_split.empty() ? input_shape[split_dim] / number_split : size_split[i];
|
||||
output_shape[split_dim] = split_dim_i;
|
||||
outputs_[i]->set_shape(output_shape);
|
||||
outputs_[i]->set_data_type(input->data_type());
|
||||
outputs_[i]->SetFormat(input->GetFormat());
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,84 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/squeeze.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
// Mutable-schema build: axis attribute lives on the editable flatbuffer table.
std::vector<int> Squeeze::GetAxis() const { return this->primitive->value.AsSqueeze()->axis; }

void Squeeze::SetAxis(const std::vector<int> &axis) { this->primitive->value.AsSqueeze()->axis = axis; }

#else

// Read-only schema build: getter copies out of the serialized flatbuffer;
// setter is intentionally a no-op.
std::vector<int> Squeeze::GetAxis() const {
  auto fb_vector = this->primitive->value_as_Squeeze()->axis();
  return std::vector<int>(fb_vector->begin(), fb_vector->end());
}

void Squeeze::SetAxis(const std::vector<int> &axis) {}
#endif
|
||||
namespace {
|
||||
constexpr int kSqueezeInputNum = 1;
|
||||
constexpr int kSqueezeOutputNum = 1;
|
||||
} // namespace
|
||||
int Squeeze::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
if (kSqueezeInputNum != inputs_.size()) {
|
||||
MS_LOG(ERROR) << "Add should has " << kSqueezeInputNum << " inputs";
|
||||
return -1;
|
||||
}
|
||||
if (kSqueezeOutputNum != outputs_.size()) {
|
||||
MS_LOG(ERROR) << "Add should has " << kSqueezeOutputNum << " outputs";
|
||||
return -1;
|
||||
}
|
||||
auto *in_tensor = inputs_.front();
|
||||
auto in_shape = in_tensor->shape();
|
||||
std::vector<int> out_shape;
|
||||
|
||||
// todo: getAxis
|
||||
auto axis = GetAxis();
|
||||
std::vector<int> axes_;
|
||||
for (auto iter = axis.begin(); iter != axis.end(); iter++) {
|
||||
axes_.push_back(*iter);
|
||||
}
|
||||
|
||||
if (axes_.size() == 0) {
|
||||
for (int i = 0; i < in_shape.size(); i++) {
|
||||
if (in_shape[i] != 1) {
|
||||
out_shape.push_back(in_shape[i]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
int axisIdx = 0;
|
||||
for (int i = 0; i < in_shape.size(); i++) {
|
||||
if (axisIdx < axes_.size() && axes_[axisIdx] == i) {
|
||||
MS_ASSERT(in_shape[i] == 1);
|
||||
axisIdx++;
|
||||
continue;
|
||||
} else {
|
||||
out_shape.push_back(in_shape[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
outputs_.front()->set_shape(out_shape);
|
||||
outputs_.front()->set_data_type(in_tensor->data_type());
|
||||
outputs_.front()->SetFormat(in_tensor->GetFormat());
|
||||
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,93 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/stack.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
// Mutable-schema build: attributes live on the editable flatbuffer table.
int Stack::GetAxis() const { return this->primitive->value.AsStack()->axis; }
int Stack::GetN() const { return this->primitive->value.AsStack()->n; }
std::vector<int> Stack::GetIsScale() const { return this->primitive->value.AsStack()->isScale; }

void Stack::SetAxis(int axis) { this->primitive->value.AsStack()->axis = axis; }
void Stack::SetN(int n) { this->primitive->value.AsStack()->n = n; }
void Stack::SetIsScale(const std::vector<int> &is_scale) { this->primitive->value.AsStack()->isScale = is_scale; }

#else

// Read-only schema build: getters copy out of the serialized flatbuffer;
// setters are intentionally no-ops.
int Stack::GetAxis() const { return this->primitive->value_as_Stack()->axis(); }
int Stack::GetN() const { return this->primitive->value_as_Stack()->n(); }
std::vector<int> Stack::GetIsScale() const {
  auto fb_vector = this->primitive->value_as_Stack()->isScale();
  return std::vector<int>(fb_vector->begin(), fb_vector->end());
}

void Stack::SetAxis(int axis) {}
void Stack::SetN(int n) {}
void Stack::SetIsScale(const std::vector<int> &is_scale) {}
#endif
|
||||
namespace {
|
||||
constexpr int kStackOutputNum = 1;
|
||||
constexpr int kStackMinInputNum = 2;
|
||||
} // namespace
|
||||
|
||||
int Stack::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
if (outputs.size() != kStackOutputNum) {
|
||||
MS_LOG(ERROR) << "Invalid output size:" << outputs.size();
|
||||
return 1;
|
||||
}
|
||||
if (inputs.size() < kStackMinInputNum) {
|
||||
MS_LOG(ERROR) << "Invalid input size " << inputs.size();
|
||||
return 1;
|
||||
}
|
||||
auto input = inputs.at(0);
|
||||
auto input_shape = input->shape();
|
||||
|
||||
std::vector<int32_t> output_shape = input_shape;
|
||||
int axis = GetAxis() < 0 ? GetAxis() + input_shape.size() : GetAxis();
|
||||
if (axis < 0 || axis > input_shape.size()) {
|
||||
MS_LOG(ERROR) << "Invalid axis " << GetAxis();
|
||||
return 1;
|
||||
}
|
||||
schema::Format input0_format = input->GetFormat();
|
||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||
if (inputs[i]->GetFormat() != input0_format) {
|
||||
MS_LOG(ERROR) << "All inputs should have the same format!";
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto input_shape_tmp = inputs[i]->shape();
|
||||
if (input_shape_tmp.size() != input_shape.size()) {
|
||||
MS_LOG(ERROR) << "All input shape size should be the same!";
|
||||
return 1;
|
||||
}
|
||||
for (size_t j = 0; j < input_shape.size(); ++j) {
|
||||
if (input_shape_tmp[j] != input_shape[j]) {
|
||||
MS_LOG(ERROR) << "All input shape should be the same!";
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
output_shape.insert(output_shape.begin() + axis, inputs.size());
|
||||
outputs[0]->set_shape(output_shape);
|
||||
outputs[0]->set_data_type(input->data_type());
|
||||
outputs[0]->SetFormat(input->GetFormat());
|
||||
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,221 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/strided_slice.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
// Mutable-schema build: all mask/range attributes live on the editable
// flatbuffer table.
int StridedSlice::GetBeginMask() const { return this->primitive->value.AsStridedSlice()->beginMask; }
int StridedSlice::GetEndMask() const { return this->primitive->value.AsStridedSlice()->endMask; }
int StridedSlice::GetEllipsisMask() const { return this->primitive->value.AsStridedSlice()->ellipsisMask; }
int StridedSlice::GetNewAxisMask() const { return this->primitive->value.AsStridedSlice()->newAxisMask; }
int StridedSlice::GetShrinkAxisMask() const { return this->primitive->value.AsStridedSlice()->shrinkAxisMask; }
std::vector<int> StridedSlice::GetBegin() const { return this->primitive->value.AsStridedSlice()->begin; }
std::vector<int> StridedSlice::GetEnd() const { return this->primitive->value.AsStridedSlice()->end; }
std::vector<int> StridedSlice::GetStride() const { return this->primitive->value.AsStridedSlice()->stride; }
std::vector<int> StridedSlice::GetIsScale() const { return this->primitive->value.AsStridedSlice()->isScale; }

void StridedSlice::SetBeginMask(int begin_mask) { this->primitive->value.AsStridedSlice()->beginMask = begin_mask; }
void StridedSlice::SetEndMask(int end_mask) { this->primitive->value.AsStridedSlice()->endMask = end_mask; }
void StridedSlice::SetEllipsisMask(int ellipsis_mask) {
  this->primitive->value.AsStridedSlice()->ellipsisMask = ellipsis_mask;
}
void StridedSlice::SetNewAxisMask(int new_axis_mask) {
  this->primitive->value.AsStridedSlice()->newAxisMask = new_axis_mask;
}
void StridedSlice::SetShrinkAxisMask(int shrink_axis_mask) {
  this->primitive->value.AsStridedSlice()->shrinkAxisMask = shrink_axis_mask;
}
void StridedSlice::SetBegin(const std::vector<int> &begin) { this->primitive->value.AsStridedSlice()->begin = begin; }
void StridedSlice::SetEnd(const std::vector<int> &end) { this->primitive->value.AsStridedSlice()->end = end; }
void StridedSlice::SetStride(const std::vector<int> &stride) {
  this->primitive->value.AsStridedSlice()->stride = stride;
}
void StridedSlice::SetIsScale(const std::vector<int> &is_scale) {
  this->primitive->value.AsStridedSlice()->isScale = is_scale;
}

#else

// Read-only schema build: getters copy out of the serialized flatbuffer;
// setters are intentionally no-ops.
int StridedSlice::GetBeginMask() const { return this->primitive->value_as_StridedSlice()->beginMask(); }
int StridedSlice::GetEndMask() const { return this->primitive->value_as_StridedSlice()->endMask(); }
int StridedSlice::GetEllipsisMask() const { return this->primitive->value_as_StridedSlice()->ellipsisMask(); }
int StridedSlice::GetNewAxisMask() const { return this->primitive->value_as_StridedSlice()->newAxisMask(); }
int StridedSlice::GetShrinkAxisMask() const { return this->primitive->value_as_StridedSlice()->shrinkAxisMask(); }
std::vector<int> StridedSlice::GetBegin() const {
  auto fb_vector = this->primitive->value_as_StridedSlice()->begin();
  return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
std::vector<int> StridedSlice::GetEnd() const {
  auto fb_vector = this->primitive->value_as_StridedSlice()->end();
  return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
std::vector<int> StridedSlice::GetStride() const {
  auto fb_vector = this->primitive->value_as_StridedSlice()->stride();
  return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
std::vector<int> StridedSlice::GetIsScale() const {
  auto fb_vector = this->primitive->value_as_StridedSlice()->isScale();
  return std::vector<int>(fb_vector->begin(), fb_vector->end());
}

void StridedSlice::SetBeginMask(int begin_mask) {}
void StridedSlice::SetEndMask(int end_mask) {}
void StridedSlice::SetEllipsisMask(int ellipsis_mask) {}
void StridedSlice::SetNewAxisMask(int new_axis_mask) {}
void StridedSlice::SetShrinkAxisMask(int shrink_axis_mask) {}
void StridedSlice::SetBegin(const std::vector<int> &begin) {}
void StridedSlice::SetEnd(const std::vector<int> &end) {}
void StridedSlice::SetStride(const std::vector<int> &stride) {}
void StridedSlice::SetIsScale(const std::vector<int> &is_scale) {}
#endif
|
||||
namespace {
// StridedSlice consumes exactly one tensor and produces exactly one.
constexpr int kStridedSliceOutputNum = 1;
constexpr int kStridedSliceInputNum = 1;
}  // namespace

// Inserts a length-1 dimension wherever the new-axis mask is set: ndim_ and
// in_shape_ grow, position i gets a fixed 0..1 slice with stride 1, and a
// full-range entry is appended at the back so begins_/ends_/strides_ keep the
// same length as in_shape_. All other per-dim masks are cleared at position i.
// NOTE(review): the compensating entries go to the *back* while the insertion
// happens at index i — this relies on the later mask passes indexing purely by
// position; confirm against the kernel before reordering anything here.
void StridedSlice::ApplyNewAxisMask() {
  for (int i = 0; i < new_axis_mask_.size(); i++) {
    if (new_axis_mask_.at(i)) {
      ndim_ += 1;
      in_shape_.insert(in_shape_.begin() + i, 1);
      begins_.at(i) = 0;
      ends_.at(i) = 1;
      strides_.at(i) = 1;

      begins_.emplace_back(0);
      ends_.emplace_back(in_shape_.at(ndim_ - 1));
      strides_.emplace_back(1);

      begins_mask_.at(i) = false;
      ends_mask_.at(i) = false;
      ellipsis_mask_.at(i) = false;
      shrink_axis_mask_.at(i) = false;
    }
  }
}
|
||||
|
||||
// Removes from out_shape every dimension whose shrink-axis bit is set, and
// pins that dimension's slice to a single element (ends = begins + 1,
// stride = 1) as a side effect on the members. Dimensions beyond the mask
// length are passed through unchanged. Returns the shrunk shape.
std::vector<int> StridedSlice::ApplyShrinkMask(std::vector<int> out_shape) {
  auto old_out_shape = out_shape;
  out_shape.clear();
  for (int i = 0; i < shrink_axis_mask_.size(); i++) {
    if (shrink_axis_mask_.at(i)) {
      // Shrunk dim: keep exactly one element and drop the dim from the shape.
      ends_.at(i) = begins_.at(i) + 1;
      strides_.at(i) = 1;
    } else {
      out_shape.emplace_back(old_out_shape.at(i));
    }
  }
  // Dims not covered by the mask vector are kept as-is.
  for (int i = shrink_axis_mask_.size(); i < old_out_shape.size(); i++) {
    out_shape.emplace_back(old_out_shape.at(i));
  }
  return out_shape;
}
|
||||
|
||||
/*only one bit will be used if multiple bits are true.*/
// Expands the first set ellipsis bit into a full-range slice over that
// dimension (begin 0, end = dimension extent) and stops; any further
// ellipsis bits are ignored.
void StridedSlice::ApplyEllipsisMask() {
  for (int i = 0; i < ellipsis_mask_.size(); i++) {
    if (ellipsis_mask_.at(i)) {
      begins_.at(i) = 0;
      ends_.at(i) = in_shape_.at(i);
      break;
    }
  }
}
|
||||
|
||||
void StridedSlice::ApplyBeginMask() {
|
||||
for (int i = 0; i < ndim_; i++) {
|
||||
if (begins_mask_.at(i)) {
|
||||
begins_.at(i) = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void StridedSlice::ApplyEndMask() {
|
||||
for (int i = 0; i < ndim_; i++) {
|
||||
if (ends_mask_.at(i)) {
|
||||
ends_.at(i) = in_shape_.at(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Infers the StridedSlice output shape: loads the begin/end/stride attributes
// and the five bit masks into per-dimension members, applies the masks in the
// fixed order new-axis -> begin -> end -> ellipsis, then derives the shape and
// finally drops shrink-axis dimensions. Returns 0 on success, 1 on error.
int StridedSlice::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
  MS_ASSERT(this->primitive != nullptr);
  if (outputs.size() != kStridedSliceOutputNum) {
    MS_LOG(ERROR) << "Invalid output size:" << outputs.size();
    return 1;
  }
  if (inputs.size() != kStridedSliceInputNum) {
    MS_LOG(ERROR) << "Invalid input size " << inputs.size();
    return 1;
  }
  auto input = inputs.at(0);
  MS_ASSERT(input != nullptr);
  auto input_shape = input->shape();
  std::vector<int> output_shape;

  // Fetch each attribute vector once; the original re-created a temporary
  // std::vector on every loop iteration via GetBegin()/GetEnd()/GetStride().
  auto begin_attr = GetBegin();
  auto end_attr = GetEnd();
  auto stride_attr = GetStride();
  ndim_ = static_cast<int>(begin_attr.size());

  // The original asserted against an undeclared `strided_slice_prim`; assert
  // on the actual attribute vectors instead.
  MS_ASSERT(ndim_ == static_cast<int>(end_attr.size()));
  MS_ASSERT(ndim_ == static_cast<int>(stride_attr.size()));
  MS_ASSERT(ndim_ == static_cast<int>(input_shape.size()));

  // Reset the cached vectors so repeated InferShape calls do not append to
  // state left over from a previous call (the original never cleared them).
  in_shape_.clear();
  begins_.clear();
  ends_.clear();
  strides_.clear();
  for (int i = 0; i < ndim_; i++) {
    in_shape_.emplace_back(input_shape.at(i));
    begins_.emplace_back(begin_attr[i]);
    ends_.emplace_back(end_attr[i]);
    strides_.emplace_back(stride_attr[i]);
  }

  // One boolean flag per dimension, decoded from the bit masks below.
  begins_mask_.resize(ndim_);
  ends_mask_.resize(ndim_);
  ellipsis_mask_.resize(ndim_);
  new_axis_mask_.resize(ndim_);
  shrink_axis_mask_.resize(ndim_);

  // Convert each mask's i-th bit into the i-th per-dimension flag.
  for (int i = 0; i < ndim_; i++) {
    begins_mask_.at(i) = static_cast<uint32_t>(GetBeginMask()) & (1 << i);
    ends_mask_.at(i) = static_cast<uint32_t>(GetEndMask()) & (1 << i);
    ellipsis_mask_.at(i) = static_cast<uint32_t>(GetEllipsisMask()) & (1 << i);
    new_axis_mask_.at(i) = static_cast<uint32_t>(GetNewAxisMask()) & (1 << i);
    shrink_axis_mask_.at(i) = static_cast<uint32_t>(GetShrinkAxisMask()) & (1 << i);
  }

  // Same application order as the original implementation.
  ApplyNewAxisMask();
  ApplyBeginMask();
  ApplyEndMask();
  ApplyEllipsisMask();

  output_shape.clear();
  output_shape.resize(in_shape_.size());
  for (size_t i = 0; i < in_shape_.size(); i++) {
    if (static_cast<int>(i) < ndim_ && new_axis_mask_.at(i)) {
      output_shape.at(i) = 1;
    } else {
      // NOTE(review): truncating division; negative strides and partial last
      // steps are not handled here — confirm callers never produce them.
      output_shape.at(i) = (ends_.at(i) - begins_.at(i)) / strides_.at(i);
    }
  }

  output_shape = ApplyShrinkMask(output_shape);

  outputs.front()->set_shape(output_shape);
  outputs.front()->set_data_type(input->data_type());
  outputs[0]->SetFormat(input->GetFormat());

  return 0;
}
|
||||
} // namespace mindspore
|
|
@ -1,55 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/tile.h"
|
||||
#include <algorithm>
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
// Mutable-schema build: multiples attribute lives on the editable table.
std::vector<int> Tile::GetMultiples() const { return this->primitive->value.AsTile()->multiples; }

void Tile::SetMultiples(const std::vector<int> &multiples) { this->primitive->value.AsTile()->multiples = multiples; }

#else

// Read-only schema build: getter copies out of the serialized flatbuffer;
// setter is intentionally a no-op.
std::vector<int> Tile::GetMultiples() const {
  auto fb_vector = this->primitive->value_as_Tile()->multiples();
  return std::vector<int>(fb_vector->begin(), fb_vector->end());
}

void Tile::SetMultiples(const std::vector<int> &multiples) {}
#endif
|
||||
// Infers the output shape of Tile: out[i] = in[i] * multiples[i] for every
// input dimension. Returns 0 on success.
int Tile::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
  MS_ASSERT(this->primitive != nullptr);
  auto input = inputs_.front();
  MS_ASSERT(input != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);

  // Take a single copy: the original passed begin()/end() iterators from TWO
  // different GetMultiples() temporaries to std::copy — undefined behavior.
  std::vector<int> multiples = GetMultiples();
  auto in_shape = input->shape();  // hoisted; the original re-fetched per iteration
  // multiples[i] is read for every input dim, so it must cover the full rank.
  MS_ASSERT(multiples.size() >= in_shape.size());

  std::vector<int> out_shape;
  out_shape.reserve(in_shape.size());
  for (size_t i = 0; i < in_shape.size(); ++i) {
    out_shape.push_back(in_shape[i] * multiples[i]);
  }

  output->SetFormat(input->GetFormat());
  output->set_shape(out_shape);
  output->set_data_type(input->data_type());
  return 0;
}
|
||||
} // namespace mindspore
|
|
@ -1,62 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/topk.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
// Mutable-schema build: attributes live on the editable flatbuffer table.
int TopK::GetK() const { return this->primitive->value.AsTopK()->k; }
bool TopK::GetSorted() const { return this->primitive->value.AsTopK()->sorted; }

void TopK::SetK(int k) { this->primitive->value.AsTopK()->k = k; }
void TopK::SetSorted(bool sorted) { this->primitive->value.AsTopK()->sorted = sorted; }

#else

// Read-only schema build: getters read the serialized table; setters are
// intentionally no-ops.
int TopK::GetK() const { return this->primitive->value_as_TopK()->k(); }
bool TopK::GetSorted() const { return this->primitive->value_as_TopK()->sorted(); }

void TopK::SetK(int k) {}
void TopK::SetSorted(bool sorted) {}
#endif
|
||||
// Infers shapes for TopK's two outputs — values (input dtype) and indices
// (int32). Both take the input shape with the innermost dimension replaced
// by k. Returns 0 on success, 1 on error.
int TopK::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
  MS_ASSERT(this->primitive != nullptr);
  if (inputs_.size() != kSingleNum || outputs_.size() != kDoubleNum) {
    MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size();
    return 1;
  }
  auto input = inputs_.front();
  MS_ASSERT(input != nullptr);
  auto output0 = outputs_.front();
  MS_ASSERT(output0 != nullptr);
  auto output1 = outputs_.at(1);
  MS_ASSERT(output1 != nullptr);
  // (The original asserted on an undeclared `topk_prim`; removed.)

  auto out_shape = input->shape();
  // Guard the size()-1 index below against a rank-0 input.
  if (out_shape.empty()) {
    MS_LOG(ERROR) << "input shape is empty!";
    return 1;
  }
  // Top-k selects along the innermost dimension.
  out_shape[out_shape.size() - 1] = GetK();

  output0->set_shape(out_shape);
  output0->set_data_type(input->data_type());
  output0->SetFormat(input->GetFormat());

  // The indices output is always int32, regardless of the input dtype.
  output1->set_shape(out_shape);
  output1->set_data_type(kNumberTypeInt32);
  output1->SetFormat(input->GetFormat());
  return 0;
}
|
||||
} // namespace mindspore
|
|
@ -1,69 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/transpose.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
// Mutable-schema build: attributes live on the editable flatbuffer table.
std::vector<int> Transpose::GetPerm() const { return this->primitive->value.AsTranspose()->perm; }
bool Transpose::GetConjugate() const { return this->primitive->value.AsTranspose()->conjugate; }

void Transpose::SetPerm(const std::vector<int> &perm) { this->primitive->value.AsTranspose()->perm = perm; }
void Transpose::SetConjugate(bool conjugate) { this->primitive->value.AsTranspose()->conjugate = conjugate; }

#else

// Read-only schema build: getters copy out of the serialized flatbuffer;
// setters are intentionally no-ops.
std::vector<int> Transpose::GetPerm() const {
  auto fb_vector = this->primitive->value_as_Transpose()->perm();
  return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
bool Transpose::GetConjugate() const { return this->primitive->value_as_Transpose()->conjugate(); }

void Transpose::SetPerm(const std::vector<int> &perm) {}
void Transpose::SetConjugate(bool conjugate) {}
#endif
|
||||
int Transpose::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
auto input = inputs_.front();
|
||||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
|
||||
MS_ASSERT(inputs_.size() == kSingleNum);
|
||||
MS_ASSERT(outputs_.size() == kSingleNum);
|
||||
|
||||
int conjugate = GetConjugate();
|
||||
if (conjugate) {
|
||||
MS_LOG(ERROR) << "Transpose conjugate is not support currently";
|
||||
return 1;
|
||||
}
|
||||
std::vector<int> perm;
|
||||
perm.insert(perm.begin(), GetPerm().begin(), GetPerm().end());
|
||||
|
||||
std::vector<int> in_shape = input->shape();
|
||||
std::vector<int> out_shape;
|
||||
out_shape.resize(perm.size());
|
||||
for (int i = 0; i < perm.size(); ++i) {
|
||||
out_shape[i] = in_shape[perm[i]];
|
||||
}
|
||||
|
||||
output->set_shape(out_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,52 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/unique.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int Unique::GetOutType() const { return this->primitive->value.AsUnique()->outType; }
|
||||
|
||||
void Unique::SetOutType(int out_type) { this->primitive->value.AsUnique()->outType = out_type; }
|
||||
|
||||
#else
|
||||
|
||||
int Unique::GetOutType() const { return this->primitive->value_as_Unique()->outType(); }
|
||||
|
||||
void Unique::SetOutType(int out_type) {}
|
||||
#endif
|
||||
int Unique::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
if (inputs_.size() != kSingleNum || outputs_.size() != kDoubleNum) {
|
||||
MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size();
|
||||
return 1;
|
||||
}
|
||||
auto &input = inputs_.at(0);
|
||||
MS_ASSERT(input != nullptr);
|
||||
auto &output0 = outputs_.at(0);
|
||||
MS_ASSERT(output0 != nullptr);
|
||||
auto &output1 = outputs_.at(1);
|
||||
MS_ASSERT(output1 != nullptr);
|
||||
output0->set_shape(input->shape());
|
||||
output0->set_data_type(input->data_type());
|
||||
output1->set_shape(input->shape());
|
||||
output1->set_data_type(kNumberTypeInt32);
|
||||
output1->SetFormat(input->GetFormat());
|
||||
output0->SetFormat(input->GetFormat());
|
||||
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,80 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/unsqueeze.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
std::vector<int> Unsqueeze::GetAxis() const { return this->primitive->value.AsUnsqueeze()->axis; }
|
||||
|
||||
void Unsqueeze::SetAxis(const std::vector<int> &axis) { this->primitive->value.AsUnsqueeze()->axis = axis; }
|
||||
|
||||
#else
|
||||
bool predicate(int n) { return n != 1; }
|
||||
std::vector<int> Unsqueeze::GetAxis() const {
|
||||
auto fb_vector = this->primitive->value_as_Unsqueeze()->axis();
|
||||
return std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||
}
|
||||
|
||||
void Unsqueeze::SetAxis(const std::vector<int> &axis) {}
|
||||
#endif
|
||||
int Unsqueeze::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
auto input = inputs_.front();
|
||||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
if (inputs_.size() != kSingleNum) {
|
||||
MS_LOG(ERROR) << "input size is invalid";
|
||||
}
|
||||
if (outputs_.size() != kSingleNum) {
|
||||
MS_LOG(ERROR) << "output size is invalid";
|
||||
}
|
||||
|
||||
auto dims = GetAxis().data();
|
||||
auto in_shape = input->shape();
|
||||
auto in_rank = in_shape.size();
|
||||
auto dim_rank = GetAxis().size();
|
||||
std::vector<int> out_shape;
|
||||
|
||||
if (dim_rank == 0) {
|
||||
std::copy_if(in_shape.begin(), in_shape.end(), out_shape.begin(), [](int n) -> bool { return n != 1; });
|
||||
} else {
|
||||
auto sz = in_rank + dim_rank;
|
||||
int in_itr = 0;
|
||||
int ax_itr = 0;
|
||||
for (int i = 0; i < sz; i++) {
|
||||
if (ax_itr < dim_rank && dims[ax_itr] == i) {
|
||||
out_shape.emplace_back(1);
|
||||
ax_itr++;
|
||||
} else if (ax_itr < dim_rank && dims[ax_itr] + sz == i) {
|
||||
out_shape.emplace_back(1);
|
||||
ax_itr++;
|
||||
} else {
|
||||
if (in_shape[in_itr] > 1) {
|
||||
out_shape.emplace_back(in_shape[in_itr]);
|
||||
}
|
||||
in_itr++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
output->SetFormat(input->GetFormat());
|
||||
output->set_shape(out_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,59 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/unstack.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int Unstack::GetNum() const { return this->primitive->value.AsUnstack()->num; }
|
||||
int Unstack::GetAxis() const { return this->primitive->value.AsUnstack()->axis; }
|
||||
|
||||
void Unstack::SetNum(int num) { this->primitive->value.AsUnstack()->num = num; }
|
||||
void Unstack::SetAxis(int axis) { this->primitive->value.AsUnstack()->axis = axis; }
|
||||
|
||||
#else
|
||||
|
||||
int Unstack::GetNum() const { return this->primitive->value_as_Unstack()->num(); }
|
||||
int Unstack::GetAxis() const { return this->primitive->value_as_Unstack()->axis(); }
|
||||
|
||||
void Unstack::SetNum(int num) {}
|
||||
void Unstack::SetAxis(int axis) {}
|
||||
#endif
|
||||
int Unstack::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
|
||||
auto input = inputs.at(0);
|
||||
MS_ASSERT(input != nullptr);
|
||||
auto input_shape = input->shape();
|
||||
int axis = GetAxis() < 0 ? GetAxis() + input_shape.size() : GetAxis();
|
||||
if (axis < 0 || axis >= input_shape.size()) {
|
||||
MS_LOG(ERROR) << "Invalid axis " << GetAxis();
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::vector<int> output_shape;
|
||||
for (size_t i = 0; i < input_shape.size(); ++i) {
|
||||
if (i != axis) {
|
||||
output_shape.push_back(input_shape.at(i));
|
||||
}
|
||||
}
|
||||
for (auto &out : outputs) {
|
||||
MS_ASSERT(out != nullptr);
|
||||
out->set_shape(output_shape);
|
||||
out->set_data_type(input->data_type());
|
||||
out->SetFormat(input->GetFormat());
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,93 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/where.h"
|
||||
|
||||
namespace mindspore {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
std::vector<bool> Where::GetCondition() const { return this->primitive->value.AsWhere()->condition; }
|
||||
|
||||
void Where::SetCondition(const std::vector<bool> &condition) {
|
||||
this->primitive->value.AsWhere()->condition = condition;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
std::vector<bool> Where::GetCondition() const {
|
||||
auto fb_vector = this->primitive->value_as_Where()->condition();
|
||||
return std::vector<bool>(fb_vector->begin(), fb_vector->end());
|
||||
}
|
||||
|
||||
void Where::SetCondition(const std::vector<bool> &condition) {}
|
||||
#endif
|
||||
int Where::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
auto input = inputs_.front();
|
||||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) {
|
||||
MS_LOG(ERROR) << "where input or output number invalid, Input size:" << inputs_.size()
|
||||
<< ", output size: " << outputs_.size();
|
||||
return 1;
|
||||
}
|
||||
if (inputs_.size() < 3) {
|
||||
MS_LOG(ERROR) << "Input shape tensors should b";
|
||||
return 1;
|
||||
}
|
||||
auto input0 = inputs_.at(0);
|
||||
auto input1 = inputs_.at(1);
|
||||
auto input2 = inputs_.at(2);
|
||||
int num = input0->ElementsNum();
|
||||
int num1 = input1->ElementsNum();
|
||||
int num2 = input2->ElementsNum();
|
||||
int nummax = num > num1 ? num : (num1 > num2 ? num1 : num2);
|
||||
|
||||
auto shape_tmp = inputs_.at(0)->shape();
|
||||
auto shape_tmp1 = inputs_.at(1)->shape();
|
||||
auto shape_tmp2 = inputs_.at(2)->shape();
|
||||
int axisout = 0;
|
||||
int temp = 0;
|
||||
for (int j = 0; j < shape_tmp.size(); j++) {
|
||||
if (shape_tmp[j] == shape_tmp1[j] && shape_tmp[j] != shape_tmp2[j]) {
|
||||
axisout = j;
|
||||
break;
|
||||
}
|
||||
if (shape_tmp[j] == shape_tmp2[j] && shape_tmp[j] != shape_tmp1[j]) {
|
||||
axisout = j;
|
||||
break;
|
||||
}
|
||||
if (shape_tmp1[j] == shape_tmp2[j] && shape_tmp[j] != shape_tmp1[j]) {
|
||||
axisout = j;
|
||||
break;
|
||||
}
|
||||
temp += 1;
|
||||
if (temp == shape_tmp.size()) {
|
||||
outputs_[0]->set_shape(shape_tmp);
|
||||
output->set_data_type(input->data_type());
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
auto output_shape = shape_tmp;
|
||||
output_shape[axisout] = nummax;
|
||||
outputs_[0]->set_shape(output_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -21,6 +21,7 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
#include "schema/model_generated.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
|
||||
namespace mindspore {
|
||||
#define MS_API __attribute__((visibility("default")))
|
||||
|
@ -34,7 +35,7 @@ class ModelImpl;
|
|||
/// \brief Primitive defined as prototype of operator.
|
||||
///
|
||||
/// \note List public class and interface for reference.
|
||||
class Primitive;
|
||||
class PrimitiveC;
|
||||
|
||||
/// \brief Model defined model in MindSpore Lite for managing graph.
|
||||
class MS_API Model {
|
||||
|
@ -60,7 +61,7 @@ class MS_API Model {
|
|||
/// \param[in] name Define name of primitive to be returned.
|
||||
///
|
||||
/// \return the pointer of MindSpore Lite Primitive.
|
||||
lite::Primitive *GetOp(const std::string &name) const;
|
||||
PrimitiveC *GetOp(const std::string &name) const;
|
||||
|
||||
/// \brief Get graph defined in flatbuffers.
|
||||
///
|
||||
|
@ -97,7 +98,7 @@ class MS_API ModelBuilder {
|
|||
/// \param[in] inputs Define input edge of primitive to be added.
|
||||
///
|
||||
/// \return ID of the added primitive.
|
||||
virtual std::string AddOp(const lite::Primitive &op, const std::vector<OutEdge> &inputs) = 0;
|
||||
virtual std::string AddOp(const PrimitiveC &op, const std::vector<OutEdge> &inputs) = 0;
|
||||
|
||||
/// \brief Finish constructing the model.
|
||||
///
|
||||
|
|
|
@ -1,64 +1,64 @@
|
|||
set(LITE_SRC
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common/graph_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common/ms_tensor_utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/runtime/allocator.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/runtime/runtime_api.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/runtime/thread_pool.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/runtime/workspace_pool.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ir/tensor.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/context.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/executor.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/populate_parameter.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc
|
||||
)
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common/graph_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common/ms_tensor_utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/runtime/allocator.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/runtime/runtime_api.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/runtime/thread_pool.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/runtime/workspace_pool.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ir/tensor.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/context.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/executor.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/populate_parameter.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc
|
||||
)
|
||||
|
||||
if (SUPPORT_GPU)
|
||||
list(APPEND LITE_SRC ${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/subgraph_opencl_kernel.cc)
|
||||
list(APPEND LITE_SRC ${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/utils.cc)
|
||||
endif()
|
||||
endif ()
|
||||
|
||||
set(LITE_SRC
|
||||
${LITE_SRC}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/model.cc
|
||||
)
|
||||
${LITE_SRC}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/model.cc
|
||||
)
|
||||
|
||||
if (SUPPORT_GPU)
|
||||
set(LITE_SRC
|
||||
${LITE_SRC}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_executor.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_allocator.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_runtime.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_wrapper.cc
|
||||
)
|
||||
set(LITE_SRC
|
||||
${LITE_SRC}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_executor.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_allocator.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_runtime.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_wrapper.cc
|
||||
)
|
||||
endif ()
|
||||
|
||||
set(ANF_SRC
|
||||
${ANF_SRC}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ir/meta_tensor_extends.cc
|
||||
)
|
||||
${ANF_SRC}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ir/meta_tensor_extends.cc
|
||||
)
|
||||
|
||||
add_library(mindspore-lite SHARED ${LITE_SRC} ${ANF_SRC})
|
||||
target_link_libraries(mindspore-lite
|
||||
cpu_kernel_mid_
|
||||
ops_mid_
|
||||
)
|
||||
cpu_kernel_mid_
|
||||
c_ops_mid
|
||||
)
|
||||
|
||||
add_subdirectory(runtime/kernel/arm)
|
||||
if (PLATFORM_ARM32 OR PLATFORM_ARM64)
|
||||
target_link_libraries(mindspore-lite log)
|
||||
endif()
|
||||
target_link_libraries(mindspore-lite log)
|
||||
endif ()
|
||||
if (BUILD_MINDDATA)
|
||||
target_link_libraries(mindspore-lite minddata-eager minddata-lite)
|
||||
target_link_libraries(mindspore-lite minddata-eager minddata-lite)
|
||||
endif ()
|
||||
|
||||
add_subdirectory(ops)
|
||||
|
||||
if("${CMAKE_BUILD_TYPE}" STREQUAL "Release" AND (PLATFORM_ARM64 OR PLATFORM_ARM32))
|
||||
add_custom_command(TARGET mindspore-lite POST_BUILD
|
||||
COMMAND ${ANDROID_NDK}/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/aarch64-linux-android/bin/strip
|
||||
${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.so)
|
||||
endif()
|
||||
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Release" AND (PLATFORM_ARM64 OR PLATFORM_ARM32))
|
||||
add_custom_command(TARGET mindspore-lite POST_BUILD
|
||||
COMMAND ${ANDROID_NDK}/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/aarch64-linux-android/bin/strip
|
||||
${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.so)
|
||||
endif ()
|
||||
|
||||
|
|
|
@ -49,7 +49,13 @@ int Executor::Run(std::vector<tensor::Tensor *> &in_tensors, std::vector<tensor:
|
|||
MS_LOG(ERROR) << "run kernel failed, name: " << kernel->name();
|
||||
return ret;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "out_tensors";
|
||||
auto tensors = kernel->out_tensors();
|
||||
MS_LOG(INFO) << kernel->name();
|
||||
for (int i = 0; i < tensors.size(); ++i) {
|
||||
auto tensor = tensors[i];
|
||||
MS_LOG(INFO) << tensor->ToString();
|
||||
}
|
||||
if (after != nullptr) {
|
||||
if (!after(PackToMSTensors(kernel->in_tensors()), PackToMSTensors(kernel->out_tensors()),
|
||||
{kernel->name(), kernel->type_str()})) {
|
||||
|
|
|
@ -1,19 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ir/primitive_value.h"
|
||||
|
||||
|
|
@ -1,47 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVE_H_
|
||||
#define MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVE_H_
|
||||
|
||||
#include "ir/value.h"
|
||||
#include "src/ops/ops.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
class PrimitiveValue : public Value {
|
||||
public:
|
||||
explicit PrimitiveValue(const lite::Primitive *prim) : primitive(prim) {}
|
||||
|
||||
const lite::Primitive *GetPrimitive() const {
|
||||
return this->primitive;
|
||||
}
|
||||
MS_DECLARE_PARENT(PrimitiveValue, Value)
|
||||
bool operator==(const Value &rhs) const override {
|
||||
if (rhs.isa<PrimitiveValue>()) {
|
||||
auto other_prim = static_cast<const PrimitiveValue &>(rhs);
|
||||
return *this == other_prim;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
const lite::Primitive *primitive = nullptr;
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVE_H_
|
||||
|
|
@ -124,13 +124,13 @@ const kernel::KernelCreator *KernelRegistry::GetCreatorArrays() { return creator
|
|||
|
||||
kernel::LiteKernel *KernelRegistry::GetKernel(const std::vector<tensor::Tensor *> &in_tensors,
|
||||
const std::vector<tensor::Tensor *> &out_tensors,
|
||||
const lite::Primitive *primitive, const Context *ctx,
|
||||
const PrimitiveC *primitive, const Context *ctx,
|
||||
const kernel::KernelKey &key) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
MS_EXCEPTION_IF_NULL(ctx);
|
||||
auto parameter = kernel::PopulateParameter(primitive);
|
||||
if (parameter == nullptr) {
|
||||
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(primitive->Type());
|
||||
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << primitive->Type();
|
||||
return nullptr;
|
||||
}
|
||||
auto creator = GetCreator(key);
|
||||
|
|
|
@ -40,7 +40,7 @@ class KernelRegistry {
|
|||
kernel::KernelCreator creator);
|
||||
bool Merge(const std::unordered_map<kernel::KernelKey, kernel::KernelCreator> &newCreators);
|
||||
kernel::LiteKernel *GetKernel(const std::vector<tensor::Tensor *> &in_tensors,
|
||||
const std::vector<tensor::Tensor *> &out_tensors, const lite::Primitive *primitive,
|
||||
const std::vector<tensor::Tensor *> &out_tensors, const PrimitiveC *primitive,
|
||||
const Context *ctx, const kernel::KernelKey &key);
|
||||
|
||||
protected:
|
||||
|
|
|
@ -24,8 +24,8 @@
|
|||
#include "src/runtime/kernel/arm/nnacl/op_base.h"
|
||||
#include "include/context.h"
|
||||
#include "src/ir/tensor.h"
|
||||
#include "src/ops/ops.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
|
||||
#ifdef ENABLE_FP16
|
||||
using FLOAT_t = float16_t;
|
||||
|
@ -59,7 +59,7 @@ class LiteKernel {
|
|||
LiteKernel() = default;
|
||||
LiteKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &in_tensors,
|
||||
const std::vector<lite::tensor::Tensor *> &out_tensors, const lite::Context *ctx,
|
||||
const lite::Primitive *primitive)
|
||||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: op_parameter_(parameter),
|
||||
in_tensors_(in_tensors),
|
||||
out_tensors_(out_tensors),
|
||||
|
@ -81,7 +81,7 @@ class LiteKernel {
|
|||
|
||||
virtual int Prepare() {
|
||||
if (!InferShapeDone()) {
|
||||
(const_cast<lite::Primitive *>(primitive_))->InferShape(in_tensors_, out_tensors_);
|
||||
(const_cast<mindspore::lite::PrimitiveC *>(primitive_))->InferShape(in_tensors_, out_tensors_);
|
||||
if (need_reinit_) {
|
||||
Init();
|
||||
}
|
||||
|
@ -154,7 +154,7 @@ class LiteKernel {
|
|||
|
||||
void set_need_reinit() { need_reinit_ = true; }
|
||||
|
||||
const lite::Primitive *GetPrimitive() const { return primitive_; }
|
||||
const mindspore::lite::PrimitiveC *GetPrimitive() const { return primitive_; }
|
||||
|
||||
protected:
|
||||
bool InferShapeDone() { return !(primitive_ != nullptr && !primitive_->GetInferFlag()) && true; }
|
||||
|
@ -162,7 +162,7 @@ class LiteKernel {
|
|||
KernelKey desc_;
|
||||
std::string name_;
|
||||
OpParameter *op_parameter_ = nullptr;
|
||||
const lite::Primitive *primitive_ = nullptr;
|
||||
const mindspore::lite::PrimitiveC *primitive_ = nullptr;
|
||||
const lite::Context *context_ = nullptr;
|
||||
// tensor will free in ~lite_session()
|
||||
std::vector<lite::tensor::Tensor *> in_tensors_;
|
||||
|
@ -181,7 +181,7 @@ class SubGraphKernel : public LiteKernel {
|
|||
const std::vector<kernel::LiteKernel *> &in_kernels,
|
||||
const std::vector<kernel::LiteKernel *> &out_kernels,
|
||||
const std::vector<kernel::LiteKernel *> &nodes, const lite::Context *ctx,
|
||||
const lite::Primitive *primitive)
|
||||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: LiteKernel(nullptr, inputs, outputs, ctx, primitive), nodes_(nodes) {
|
||||
in_kernels_ = in_kernels;
|
||||
out_kernels_ = out_kernels;
|
||||
|
@ -198,7 +198,8 @@ class SubGraphKernel : public LiteKernel {
|
|||
|
||||
typedef LiteKernel *(*KernelCreator)(const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs, OpParameter *parameter,
|
||||
const lite::Context *ctx, const KernelKey &desc, const lite::Primitive *primitive);
|
||||
const lite::Context *ctx, const KernelKey &desc,
|
||||
const mindspore::lite::PrimitiveC *primitive);
|
||||
|
||||
class LiteKernelUtil {
|
||||
public:
|
||||
|
|
|
@ -14,9 +14,101 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/unique.h"
|
||||
#include "src/ops/space_to_batch.h"
|
||||
#include "src/ops/conv2d.h"
|
||||
#include "src/ops/roi_pooling.h"
|
||||
#include "src/ops/topk.h"
|
||||
#include "src/ops/broadcast_to.h"
|
||||
#include "src/ops/unsqueeze.h"
|
||||
#include "src/ops/unstack.h"
|
||||
#include "src/ops/depth_to_space.h"
|
||||
#include "src/ops/batch_to_space.h"
|
||||
#include "src/ops/prior_box.h"
|
||||
#include "src/ops/lstm.h"
|
||||
#include "src/ops/softmax.h"
|
||||
#include "src/ops/activation.h"
|
||||
#include "src/ops/deconv2d.h"
|
||||
#include "src/ops/reduce.h"
|
||||
#include "src/ops/pooling.h"
|
||||
#include "src/ops/fused_batchnorm.h"
|
||||
#include "src/ops/batch_norm.h"
|
||||
#include "src/ops/power.h"
|
||||
#include "src/ops/range.h"
|
||||
#include "src/ops/add.h"
|
||||
#include "src/ops/sub.h"
|
||||
#include "src/ops/div.h"
|
||||
#include "src/ops/bias_add.h"
|
||||
#include "src/ops/expand_dims.h"
|
||||
#include "src/ops/full_connection.h"
|
||||
#include "src/ops/shape.h"
|
||||
#include "src/ops/elu.h"
|
||||
#include "src/ops/embedding_lookup.h"
|
||||
#include "src/ops/quant_dtype_cast.h"
|
||||
#include "src/ops/matmul.h"
|
||||
#include "src/ops/resize.h"
|
||||
#include "src/ops/tile.h"
|
||||
#include "src/ops/one_hot.h"
|
||||
#include "src/ops/space_to_depth.h"
|
||||
#include "src/ops/split.h"
|
||||
#include "src/ops/argmax.h"
|
||||
#include "src/ops/argmin.h"
|
||||
#include "src/ops/cast.h"
|
||||
#include "src/ops/reshape.h"
|
||||
#include "src/ops/scale.h"
|
||||
#include "src/ops/concat.h"
|
||||
#include "src/ops/nchw2nhwc.h"
|
||||
#include "src/ops/slice.h"
|
||||
#include "src/ops/squeeze.h"
|
||||
#include "src/ops/flatten.h"
|
||||
#include "src/ops/mean.h"
|
||||
#include "src/ops/nhwc2nchw.h"
|
||||
#include "src/ops/stack.h"
|
||||
#include "src/ops/crop.h"
|
||||
#include "src/ops/addn.h"
|
||||
#include "src/ops/gather.h"
|
||||
#include "src/ops/gather_nd.h"
|
||||
#include "src/ops/local_response_normalization.h"
|
||||
#include "src/ops/pad.h"
|
||||
#include "src/ops/prelu.h"
|
||||
#include "src/ops/caffe_p_relu.h"
|
||||
#include "src/ops/reverse_sequence.h"
|
||||
#include "src/ops/dedepthwise_conv2d.h"
|
||||
#include "src/ops/depthwise_conv2d.h"
|
||||
#include "src/ops/mul.h"
|
||||
#include "src/ops/eltwise.h"
|
||||
#include "src/ops/fill.h"
|
||||
#include "src/ops/transpose.h"
|
||||
#include "src/ops/log.h"
|
||||
#include "src/ops/abs.h"
|
||||
#include "src/ops/sin.h"
|
||||
#include "src/ops/cos.h"
|
||||
#include "src/ops/sqrt.h"
|
||||
#include "src/ops/square.h"
|
||||
#include "src/ops/exp.h"
|
||||
#include "src/ops/rsqrt.h"
|
||||
#include "src/ops/maximum.h"
|
||||
#include "src/ops/minimum.h"
|
||||
#include "src/ops/strided_slice.h"
|
||||
#include "src/ops/reverse.h"
|
||||
#include "src/ops/logical_and.h"
|
||||
#include "src/ops/logical_or.h"
|
||||
#include "src/ops/logical_not.h"
|
||||
#include "src/ops/floor_div.h"
|
||||
#include "src/ops/floor_mod.h"
|
||||
#include "src/ops/equal.h"
|
||||
#include "src/ops/not_equal.h"
|
||||
#include "src/ops/less.h"
|
||||
#include "src/ops/less_equal.h"
|
||||
#include "src/ops/greater_equal.h"
|
||||
#include "src/ops/greater.h"
|
||||
#include "src/ops/floor.h"
|
||||
#include "src/ops/squared_difference.h"
|
||||
#include "src/ops/ceil.h"
|
||||
#include "src/ops/round.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
#include "include/model.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "src/ops/ops.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
|
||||
|
@ -28,19 +120,19 @@ class ModelImpl {
|
|||
meta_graph_ = schema::GetMetaGraph(model_buf);
|
||||
}
|
||||
virtual ~ModelImpl();
|
||||
lite::Primitive *GetOp(const std::string &name) const;
|
||||
PrimitiveC *GetOp(const std::string &name) const;
|
||||
const schema::MetaGraph *meta_graph() const;
|
||||
void FreeMetaGraph();
|
||||
int BuildOps();
|
||||
|
||||
protected:
|
||||
lite::Primitive *CopyPrimitive(const schema::Primitive *src_prim);
|
||||
PrimitiveC *CopyPrimitive(const schema::Primitive *src_prim);
|
||||
|
||||
protected:
|
||||
const char *model_buf_;
|
||||
size_t buf_size_;
|
||||
const schema::MetaGraph *meta_graph_ = nullptr;
|
||||
std::map<std::string, lite::Primitive *> ops_;
|
||||
std::map<std::string, PrimitiveC *> ops_;
|
||||
};
|
||||
|
||||
ModelImpl *ModelImpl::Import(const char *model_buf, size_t size) {
|
||||
|
@ -72,7 +164,7 @@ ModelImpl *ModelImpl::Import(const char *model_buf, size_t size) {
|
|||
return model;
|
||||
}
|
||||
|
||||
lite::Primitive *ModelImpl::GetOp(const std::string &name) const {
|
||||
PrimitiveC *ModelImpl::GetOp(const std::string &name) const {
|
||||
auto iter = ops_.find(name);
|
||||
if (iter == ops_.end()) {
|
||||
return nullptr;
|
||||
|
@ -96,178 +188,178 @@ void ModelImpl::FreeMetaGraph() {
|
|||
|
||||
const schema::MetaGraph *ModelImpl::meta_graph() const { return this->meta_graph_; }
|
||||
|
||||
lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *src_prim) {
|
||||
PrimitiveC *ModelImpl::CopyPrimitive(const schema::Primitive *src_prim) {
|
||||
MS_EXCEPTION_IF_NULL(src_prim);
|
||||
auto op_type = src_prim->value_type();
|
||||
switch (op_type) {
|
||||
case schema::PrimitiveType_SoftMax:
|
||||
return new lite::SoftMax(const_cast<schema::Primitive *>(src_prim));
|
||||
return new SoftMax(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Activation:
|
||||
return new lite::Activation(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Activation(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Conv2D:
|
||||
return new lite::Conv2D(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Conv2D(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_DeConv2D:
|
||||
return new lite::DeConv2D(const_cast<schema::Primitive *>(src_prim));
|
||||
return new DeConv2D(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Reduce:
|
||||
return new lite::Reduce(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Reduce(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Pooling:
|
||||
return new lite::Pooling(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Pooling(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_DepthwiseConv2D:
|
||||
return new lite::DepthwiseConv2D(const_cast<schema::Primitive *>(src_prim));
|
||||
return new DepthwiseConv2D(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_FusedBatchNorm:
|
||||
return new lite::FusedBatchNorm(const_cast<schema::Primitive *>(src_prim));
|
||||
return new FusedBatchNorm(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_BatchNorm:
|
||||
return new lite::BatchNorm(const_cast<schema::Primitive *>(src_prim));
|
||||
return new BatchNorm(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_FullConnection:
|
||||
return new lite::FullConnection(const_cast<schema::Primitive *>(src_prim));
|
||||
return new FullConnection(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Power:
|
||||
return new lite::Power(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Power(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Range:
|
||||
return new lite::Range(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Range(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Mul:
|
||||
return new lite::Mul(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Mul(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Add:
|
||||
return new lite::Add(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Add(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Sub:
|
||||
return new lite::Sub(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Sub(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Div:
|
||||
return new lite::Div(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Div(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_BiasAdd:
|
||||
return new lite::BiasAdd(const_cast<schema::Primitive *>(src_prim));
|
||||
return new BiasAdd(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_ExpandDims:
|
||||
return new lite::ExpandDims(const_cast<schema::Primitive *>(src_prim));
|
||||
return new ExpandDims(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_ArgMax:
|
||||
return new lite::ArgMax(const_cast<schema::Primitive *>(src_prim));
|
||||
return new ArgMax(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_ArgMin:
|
||||
return new lite::ArgMin(const_cast<schema::Primitive *>(src_prim));
|
||||
return new ArgMin(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Cast:
|
||||
return new lite::Cast(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Cast(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Reshape:
|
||||
return new lite::Reshape(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Reshape(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Scale:
|
||||
return new lite::Scale(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Scale(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Eltwise:
|
||||
return new lite::Eltwise(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Eltwise(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Concat:
|
||||
return new lite::Concat(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Concat(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Fill:
|
||||
return new lite::Fill(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Fill(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Transpose:
|
||||
return new lite::Transpose(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Transpose(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Slice:
|
||||
return new lite::Slice(const_cast<schema::Primitive *>(src_prim));
|
||||
return new SliceOp(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Squeeze:
|
||||
return new lite::Squeeze(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Squeeze(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Nchw2Nhwc:
|
||||
return new lite::Nchw2Nhwc(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Nchw2Nhwc(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Nhwc2Nchw:
|
||||
return new lite::Nhwc2Nchw(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Nhwc2Nchw(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Flatten:
|
||||
return new lite::Flatten(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Flatten(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Mean:
|
||||
return new lite::Mean(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Mean(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Stack:
|
||||
return new lite::Stack(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Stack(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Crop:
|
||||
return new lite::Crop(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Crop(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_SquaredDifference:
|
||||
return new lite::SquaredDifference(const_cast<schema::Primitive *>(src_prim));
|
||||
return new SquaredDifference(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_AddN:
|
||||
return new lite::AddN(const_cast<schema::Primitive *>(src_prim));
|
||||
return new AddN(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Abs:
|
||||
return new lite::Abs(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Abs(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Sin:
|
||||
return new lite::Sin(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Sin(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Cos:
|
||||
return new lite::Cos(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Cos(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Log:
|
||||
return new lite::Log(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Log(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Sqrt:
|
||||
return new lite::Sqrt(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Sqrt(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Rsqrt:
|
||||
return new lite::Rsqrt(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Rsqrt(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Square:
|
||||
return new lite::Square(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Square(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Exp:
|
||||
return new lite::Exp(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Exp(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Gather:
|
||||
return new lite::Gather(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Gather(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_GatherNd:
|
||||
return new lite::GatherNd(const_cast<schema::Primitive *>(src_prim));
|
||||
return new GatherNd(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_LocalResponseNormalization:
|
||||
return new lite::LocalResponseNormalization(const_cast<schema::Primitive *>(src_prim));
|
||||
return new LocalResponseNormalization(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Maximum:
|
||||
return new lite::Maximum(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Maximum(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Minimum:
|
||||
return new lite::Minimum(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Minimum(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Pad:
|
||||
return new lite::Pad(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Pad(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_StridedSlice:
|
||||
return new lite::StridedSlice(const_cast<schema::Primitive *>(src_prim));
|
||||
return new StridedSlice(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Prelu:
|
||||
return new lite::Prelu(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Prelu(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_CaffePReLU:
|
||||
return new lite::CaffePReLU(const_cast<schema::Primitive *>(src_prim));
|
||||
return new CaffePReLU(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Round:
|
||||
return new lite::Round(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Round(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Reverse:
|
||||
return new lite::Reverse(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Reverse(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_ReverseSequence:
|
||||
return new lite::ReverseSequence(const_cast<schema::Primitive *>(src_prim));
|
||||
return new ReverseSequence(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_LogicalAnd:
|
||||
return new lite::LogicalAnd(const_cast<schema::Primitive *>(src_prim));
|
||||
return new LogicalAnd(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_LogicalOr:
|
||||
return new lite::LogicalOr(const_cast<schema::Primitive *>(src_prim));
|
||||
return new LogicalOr(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_LogicalNot:
|
||||
return new lite::LogicalNot(const_cast<schema::Primitive *>(src_prim));
|
||||
return new LogicalNot(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_FloorDiv:
|
||||
return new lite::FloorDiv(const_cast<schema::Primitive *>(src_prim));
|
||||
return new FloorDiv(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_FloorMod:
|
||||
return new lite::FloorMod(const_cast<schema::Primitive *>(src_prim));
|
||||
return new FloorMod(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Equal:
|
||||
return new lite::Equal(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Equal(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_NotEqual:
|
||||
return new lite::NotEqual(const_cast<schema::Primitive *>(src_prim));
|
||||
return new NotEqual(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Less:
|
||||
return new lite::Less(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Less(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_LessEqual:
|
||||
return new lite::LessEqual(const_cast<schema::Primitive *>(src_prim));
|
||||
return new LessEqual(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Greater:
|
||||
return new lite::Greater(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Greater(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_GreaterEqual:
|
||||
return new lite::GreaterEqual(const_cast<schema::Primitive *>(src_prim));
|
||||
return new GreaterEqual(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Floor:
|
||||
return new lite::Floor(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Floor(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Ceil:
|
||||
return new lite::Ceil(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Ceil(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Split:
|
||||
return new lite::Split(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Split(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_OneHot:
|
||||
return new lite::OneHot(const_cast<schema::Primitive *>(src_prim));
|
||||
return new OneHot(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_SpaceToDepth:
|
||||
return new lite::SpaceToDepth(const_cast<schema::Primitive *>(src_prim));
|
||||
return new SpaceToDepth(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Tile:
|
||||
return new lite::Tile(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Tile(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Resize:
|
||||
return new lite::Resize(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Resize(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Unstack:
|
||||
return new lite::Unstack(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Unstack(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Unique:
|
||||
return new lite::Unique(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Unique(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_TopK:
|
||||
return new lite::TopK(const_cast<schema::Primitive *>(src_prim));
|
||||
return new TopK(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_MatMul:
|
||||
return new lite::MatMul(const_cast<schema::Primitive *>(src_prim));
|
||||
return new MatMul(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_QuantDTypeCast:
|
||||
return new lite::QuantDTypeCast(const_cast<schema::Primitive *>(src_prim));
|
||||
return new QuantDTypeCast(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_EmbeddingLookup:
|
||||
return new lite::EmbeddingLookup(const_cast<schema::Primitive *>(src_prim));
|
||||
return new EmbeddingLookup(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Elu:
|
||||
return new lite::Elu(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Elu(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_DeDepthwiseConv2D:
|
||||
return new lite::DeconvDepthwiseConv2D(const_cast<schema::Primitive *>(src_prim));
|
||||
return new DeDepthwiseConv2D(const_cast<schema::Primitive *>(src_prim));
|
||||
case schema::PrimitiveType_Shape:
|
||||
return new lite::Shape(const_cast<schema::Primitive *>(src_prim));
|
||||
return new Shape(const_cast<schema::Primitive *>(src_prim));
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
@ -334,9 +426,9 @@ Model *Model::Import(const char *model_buf, size_t size) {
|
|||
|
||||
Model::~Model() { delete (this->model_impl_); }
|
||||
|
||||
lite::Primitive *Model::GetOp(const std::string &name) const {
|
||||
mindspore::lite::PrimitiveC *Model::GetOp(const std::string &name) const {
|
||||
MS_EXCEPTION_IF_NULL(model_impl_);
|
||||
return const_cast<Primitive *>(model_impl_->GetOp(name));
|
||||
return const_cast<PrimitiveC *>(model_impl_->GetOp(name));
|
||||
}
|
||||
|
||||
void Model::FreeMetaGraph() {
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
file(GLOB_RECURSE OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cc)
|
||||
file(GLOB_RECURSE C_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cc)
|
||||
|
||||
add_library(ops_mid_ OBJECT ${OPS_SRC})
|
||||
add_library(c_ops_mid OBJECT ${C_OPS_SRC})
|
|
@ -1,7 +1,7 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the License);
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
|
@ -18,7 +18,7 @@
|
|||
#include <set>
|
||||
#include <cmath>
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "c_ops/arithmetic_self.h"
|
||||
#include "src/ops/arithmetic_self.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "schema/inner/model_generated.h"
|
||||
#else
|
||||
|
@ -29,14 +29,11 @@
|
|||
#define LITE_MINDSPORE_LITE_C_OPS_ABS_H_
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class Abs : public ArithmeticSelf {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
explicit Abs(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
|
||||
#else
|
||||
explicit Abs(schema::Primitive *primitive) : ArithmeticSelf(primitive) {}
|
||||
#endif
|
||||
explicit Abs(OriginPrimitive *primitive) : ArithmeticSelf(primitive) {}
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_ABS_H_
|
|
@ -14,9 +14,10 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/activation.h"
|
||||
#include "src/ops/activation.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int Activation::GetType() const { return this->primitive->value.AsActivation()->type; }
|
||||
float Activation::GetAlpha() const { return this->primitive->value.AsActivation()->alpha; }
|
||||
|
@ -32,4 +33,5 @@ float Activation::GetAlpha() const { return this->primitive->value_as_Activation
|
|||
void Activation::SetType(int type) {}
|
||||
void Activation::SetAlpha(float alpha) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -18,7 +18,7 @@
|
|||
#include <set>
|
||||
#include <cmath>
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "mindspore/lite/c_ops/primitive_c.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "schema/inner/model_generated.h"
|
||||
#else
|
||||
|
@ -29,18 +29,16 @@
|
|||
#define LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_H_
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class Activation : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
explicit Activation(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#else
|
||||
explicit Activation(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
explicit Activation(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
|
||||
int GetType() const;
|
||||
float GetAlpha() const;
|
||||
void SetType(int type);
|
||||
void SetAlpha(float alpha);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_H_
|
|
@ -14,9 +14,10 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/activation_grad.h"
|
||||
#include "src/ops/activation_grad.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int ActivationGrad::GetType() const { return this->primitive->value.AsActivationGrad()->type; }
|
||||
|
||||
|
@ -30,4 +31,5 @@ int ActivationGrad::GetType() const { return this->primitive->value_as_Activatio
|
|||
|
||||
void ActivationGrad::SetType(int type) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -18,7 +18,7 @@
|
|||
#include <set>
|
||||
#include <cmath>
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "mindspore/lite/c_ops/primitive_c.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "schema/inner/model_generated.h"
|
||||
#else
|
||||
|
@ -29,16 +29,14 @@
|
|||
#define LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_GRAD_H_
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class ActivationGrad : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
explicit ActivationGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#else
|
||||
explicit ActivationGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
explicit ActivationGrad(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
int GetType() const;
|
||||
void SetType(int type);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_GRAD_H_
|
|
@ -14,9 +14,10 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/add.h"
|
||||
#include "src/ops/add.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int Add::GetActivationType() const { return this->primitive->value.AsAdd()->activationType; }
|
||||
|
||||
|
@ -30,4 +31,5 @@ int Add::GetActivationType() const { return this->primitive->value_as_Add()->act
|
|||
|
||||
void Add::SetActivationType(int activation_type) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -18,7 +18,7 @@
|
|||
#include <set>
|
||||
#include <cmath>
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "c_ops/arithmetic.h"
|
||||
#include "src/ops/arithmetic.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "schema/inner/model_generated.h"
|
||||
#else
|
||||
|
@ -28,16 +28,15 @@
|
|||
#define LITE_MINDSPORE_LITE_C_OPS_ADD_H_
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class Add : public Arithmetic {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
explicit Add(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
#else
|
||||
explicit Add(schema::Primitive *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
explicit Add(OriginPrimitive *primitive) : Arithmetic(primitive) {}
|
||||
|
||||
int GetActivationType() const;
|
||||
void SetActivationType(int activation_type);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_ADD_H_
|
|
@ -14,12 +14,22 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/ops.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "src/ir/tensor.h"
|
||||
#include "src/ops/addn.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int AddN::GetN() const { return this->primitive->value.AsAddN()->N; }
|
||||
|
||||
void AddN::SetN(int n) { this->primitive->value.AsAddN()->N = n; }
|
||||
|
||||
#else
|
||||
|
||||
int AddN::GetN() const { return this->primitive->value_as_AddN()->N(); }
|
||||
|
||||
void AddN::SetN(int n) {}
|
||||
#endif
|
||||
|
||||
namespace mindspore::lite {
|
||||
namespace {
|
||||
constexpr int kLeastInputNum = 2;
|
||||
}
|
||||
|
@ -48,5 +58,5 @@ int AddN::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::T
|
|||
output->set_data_type(input->data_type());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::lite
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#include <set>
|
||||
#include <cmath>
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "mindspore/lite/c_ops/primitive_c.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "schema/inner/model_generated.h"
|
||||
#else
|
||||
|
@ -29,17 +29,15 @@
|
|||
#define LITE_MINDSPORE_LITE_C_OPS_ADD_N_H_
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class AddN : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
explicit AddN(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#else
|
||||
explicit AddN(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
explicit AddN(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetN() const;
|
||||
void SetN(int n);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_ADD_N_H_
|
|
@ -14,12 +14,38 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/ops.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "src/ir/tensor.h"
|
||||
#include "src/ops/argmax.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int ArgMax::GetAxis() const { return this->primitive->value.AsArgMax()->axis; }
|
||||
bool ArgMax::GetOutMaxValue() const { return this->primitive->value.AsArgMax()->outMaxValue; }
|
||||
int ArgMax::GetTopK() const { return this->primitive->value.AsArgMax()->topK; }
|
||||
bool ArgMax::GetKeepDims() const { return this->primitive->value.AsArgMax()->keepDims; }
|
||||
int ArgMax::GetAxisType() const { return this->primitive->value.AsArgMax()->axisType; }
|
||||
|
||||
void ArgMax::SetAxis(int axis) { this->primitive->value.AsArgMax()->axis = axis; }
|
||||
void ArgMax::SetOutMaxValue(bool out_max_value) { this->primitive->value.AsArgMax()->outMaxValue = out_max_value; }
|
||||
void ArgMax::SetTopK(int top_k) { this->primitive->value.AsArgMax()->topK = top_k; }
|
||||
void ArgMax::SetKeepDims(bool keep_dims) { this->primitive->value.AsArgMax()->keepDims = keep_dims; }
|
||||
void ArgMax::SetAxisType(int axis_type) { this->primitive->value.AsArgMax()->axisType = axis_type; }
|
||||
|
||||
#else
|
||||
|
||||
int ArgMax::GetAxis() const { return this->primitive->value_as_ArgMax()->axis(); }
|
||||
bool ArgMax::GetOutMaxValue() const { return this->primitive->value_as_ArgMax()->outMaxValue(); }
|
||||
int ArgMax::GetTopK() const { return this->primitive->value_as_ArgMax()->topK(); }
|
||||
bool ArgMax::GetKeepDims() const { return this->primitive->value_as_ArgMax()->keepDims(); }
|
||||
int ArgMax::GetAxisType() const { return this->primitive->value_as_ArgMax()->axisType(); }
|
||||
|
||||
void ArgMax::SetAxis(int axis) {}
|
||||
void ArgMax::SetOutMaxValue(bool out_max_value) {}
|
||||
void ArgMax::SetTopK(int top_k) {}
|
||||
void ArgMax::SetKeepDims(bool keep_dims) {}
|
||||
void ArgMax::SetAxisType(int axis_type) {}
|
||||
#endif
|
||||
|
||||
namespace mindspore::lite {
|
||||
int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
auto input = inputs_.front();
|
||||
|
@ -30,7 +56,6 @@ int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
MS_LOG(ERROR) << "tensor number is error.";
|
||||
}
|
||||
auto argmax_prim = this->primitive->value_as_ArgMax();
|
||||
|
||||
std::vector<int> output_shape(input->shape());
|
||||
auto input_shape_size = input->shape().size();
|
||||
int axis = argmax_prim->axis() < 0 ? argmax_prim->axis() + input_shape_size : argmax_prim->axis();
|
||||
|
@ -43,11 +68,10 @@ int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
} else {
|
||||
output_shape[axis] = argmax_prim->topK();
|
||||
}
|
||||
|
||||
output->SetFormat(input->GetFormat());
|
||||
output->set_shape(output_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::lite
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#include <set>
|
||||
#include <cmath>
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "mindspore/lite/c_ops/primitive_c.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "schema/inner/model_generated.h"
|
||||
#else
|
||||
|
@ -29,13 +29,11 @@
|
|||
#define LITE_MINDSPORE_LITE_C_OPS_ARG_MAX_H_
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class ArgMax : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
explicit ArgMax(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#else
|
||||
explicit ArgMax(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
explicit ArgMax(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetAxis() const;
|
||||
bool GetOutMaxValue() const;
|
||||
|
@ -48,6 +46,7 @@ class ArgMax : public PrimitiveC {
|
|||
void SetKeepDims(bool keep_dims);
|
||||
void SetAxisType(int axis_type);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_ARG_MAX_H_
|
|
@ -14,12 +14,38 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/ops.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "src/ir/tensor.h"
|
||||
#include "src/ops/argmin.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int ArgMin::GetAxis() const { return this->primitive->value.AsArgMin()->axis; }
|
||||
bool ArgMin::GetOutMaxValue() const { return this->primitive->value.AsArgMin()->outMaxValue; }
|
||||
int ArgMin::GetTopK() const { return this->primitive->value.AsArgMin()->topK; }
|
||||
bool ArgMin::GetKeepDims() const { return this->primitive->value.AsArgMin()->keepDims; }
|
||||
int ArgMin::GetAxisType() const { return this->primitive->value.AsArgMin()->axisType; }
|
||||
|
||||
void ArgMin::SetAxis(int axis) { this->primitive->value.AsArgMin()->axis = axis; }
|
||||
void ArgMin::SetOutMaxValue(bool out_max_value) { this->primitive->value.AsArgMin()->outMaxValue = out_max_value; }
|
||||
void ArgMin::SetTopK(int top_k) { this->primitive->value.AsArgMin()->topK = top_k; }
|
||||
void ArgMin::SetKeepDims(bool keep_dims) { this->primitive->value.AsArgMin()->keepDims = keep_dims; }
|
||||
void ArgMin::SetAxisType(int axis_type) { this->primitive->value.AsArgMin()->axisType = axis_type; }
|
||||
|
||||
#else
|
||||
|
||||
int ArgMin::GetAxis() const { return this->primitive->value_as_ArgMin()->axis(); }
|
||||
bool ArgMin::GetOutMaxValue() const { return this->primitive->value_as_ArgMin()->outMaxValue(); }
|
||||
int ArgMin::GetTopK() const { return this->primitive->value_as_ArgMin()->topK(); }
|
||||
bool ArgMin::GetKeepDims() const { return this->primitive->value_as_ArgMin()->keepDims(); }
|
||||
int ArgMin::GetAxisType() const { return this->primitive->value_as_ArgMin()->axisType(); }
|
||||
|
||||
void ArgMin::SetAxis(int axis) {}
|
||||
void ArgMin::SetOutMaxValue(bool out_max_value) {}
|
||||
void ArgMin::SetTopK(int top_k) {}
|
||||
void ArgMin::SetKeepDims(bool keep_dims) {}
|
||||
void ArgMin::SetAxisType(int axis_type) {}
|
||||
#endif
|
||||
|
||||
namespace mindspore::lite {
|
||||
int ArgMin::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
auto input = inputs_.front();
|
||||
|
@ -42,10 +68,10 @@ int ArgMin::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
} else {
|
||||
output_shape[axis] = argmin_prim->topK();
|
||||
}
|
||||
|
||||
output->SetFormat(input->GetFormat());
|
||||
output->set_shape(output_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::lite
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#include <set>
|
||||
#include <cmath>
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "mindspore/lite/c_ops/primitive_c.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "schema/inner/model_generated.h"
|
||||
#else
|
||||
|
@ -29,13 +29,11 @@
|
|||
#define LITE_MINDSPORE_LITE_C_OPS_ARG_MIN_H_
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class ArgMin : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
explicit ArgMin(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#else
|
||||
explicit ArgMin(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
explicit ArgMin(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetAxis() const;
|
||||
bool GetOutMaxValue() const;
|
||||
|
@ -48,6 +46,7 @@ class ArgMin : public PrimitiveC {
|
|||
void SetKeepDims(bool keep_dims);
|
||||
void SetAxisType(int axis_type);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_ARG_MIN_H_
|
|
@ -14,13 +14,14 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/ops.h"
|
||||
#include "src/ops/arithmetic.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "src/ir/tensor.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
int Arithmetic::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
int Arithmetic::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
if (inputs_.size() != kDoubleNum) {
|
||||
MS_LOG(ERROR) << "The number of input must be " << kDoubleNum;
|
||||
|
@ -103,5 +104,5 @@ int Arithmetic::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
|
|||
output->set_shape(output_shape);
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::lite
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#include <set>
|
||||
#include <cmath>
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "mindspore/lite/c_ops/primitive_c.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "schema/inner/model_generated.h"
|
||||
#else
|
||||
|
@ -29,13 +29,11 @@
|
|||
#define LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_H_
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class Arithmetic : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
explicit Arithmetic(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#else
|
||||
explicit Arithmetic(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
explicit Arithmetic(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
bool Broadcasting() { return this->broadcasting_; }
|
||||
int NDims() { return this->ndim_; }
|
||||
|
@ -50,6 +48,7 @@ class Arithmetic : public PrimitiveC {
|
|||
std::vector<int> in_shape1_;
|
||||
std::vector<int> out_shape_;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_H_
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -14,12 +14,13 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/ops.h"
|
||||
#include "src/ops/arithmetic_self.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "src/ir/tensor.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
||||
int ArithmeticSelf::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
auto input = inputs_.front();
|
||||
|
@ -32,7 +33,7 @@ int ArithmeticSelf::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto
|
|||
return RET_OK;
|
||||
}
|
||||
output->set_shape(input->shape());
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::lite
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#include <set>
|
||||
#include <cmath>
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "mindspore/lite/c_ops/primitive_c.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "schema/inner/model_generated.h"
|
||||
#else
|
||||
|
@ -28,15 +28,14 @@
|
|||
#define LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_SELF_H_
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class ArithmeticSelf : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
explicit ArithmeticSelf(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#else
|
||||
explicit ArithmeticSelf(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
explicit ArithmeticSelf(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_SELF_H_
|
|
@ -14,9 +14,10 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/batch_norm.h"
|
||||
#include "src/ops/batch_norm.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
float BatchNorm::GetEpsilon() const { return this->primitive->value.AsBatchNorm()->epsilon; }
|
||||
|
||||
|
@ -28,4 +29,5 @@ float BatchNorm::GetEpsilon() const { return this->primitive->value_as_BatchNorm
|
|||
|
||||
void BatchNorm::SetEpsilon(float epsilon) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -18,7 +18,7 @@
|
|||
#include <set>
|
||||
#include <cmath>
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "mindspore/lite/c_ops/primitive_c.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "schema/inner/model_generated.h"
|
||||
#else
|
||||
|
@ -29,16 +29,15 @@
|
|||
#define LITE_MINDSPORE_LITE_C_OPS_BATCH_NORM_H_
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class BatchNorm : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
explicit BatchNorm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#else
|
||||
explicit BatchNorm(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
explicit BatchNorm(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
float GetEpsilon() const;
|
||||
void SetEpsilon(float epsilon);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_BATCH_NORM_H_
|
|
@ -14,12 +14,37 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/ops.h"
|
||||
#include "src/ops/batch_to_space.h"
|
||||
#include "src/common/common.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "src/ir/tensor.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
std::vector<int> BatchToSpace::GetBlockShape() const { return this->primitive->value.AsBatchToSpace()->blockShape; }
|
||||
std::vector<int> BatchToSpace::GetCrops() const { return this->primitive->value.AsBatchToSpace()->crops; }
|
||||
|
||||
void BatchToSpace::SetBlockShape(const std::vector<int> &block_shape) {
|
||||
this->primitive->value.AsBatchToSpace()->blockShape = block_shape;
|
||||
}
|
||||
void BatchToSpace::SetCrops(const std::vector<int> &crops) { this->primitive->value.AsBatchToSpace()->crops = crops; }
|
||||
|
||||
#else
|
||||
|
||||
std::vector<int> BatchToSpace::GetBlockShape() const {
|
||||
auto fb_vector = this->primitive->value_as_BatchToSpace()->blockShape();
|
||||
return std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||
}
|
||||
std::vector<int> BatchToSpace::GetCrops() const {
|
||||
auto fb_vector = this->primitive->value_as_BatchToSpace()->crops();
|
||||
return std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||
}
|
||||
|
||||
void BatchToSpace::SetBlockShape(const std::vector<int> &block_shape) {}
|
||||
void BatchToSpace::SetCrops(const std::vector<int> &crops) {}
|
||||
#endif
|
||||
namespace {
|
||||
constexpr int kBatchToSpaceOutputNum = 1;
|
||||
constexpr int kBatchToSpaceInputNum = 1;
|
||||
|
@ -27,7 +52,7 @@ constexpr int kBlockShapeSize = 2;
|
|||
constexpr int kCropsSize = 4;
|
||||
} // namespace
|
||||
|
||||
int BatchToSpace::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) {
|
||||
int BatchToSpace::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
if (outputs.size() != kBatchToSpaceOutputNum || inputs.size() != kBatchToSpaceInputNum) {
|
||||
MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
|
||||
|
@ -49,49 +74,50 @@ int BatchToSpace::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
|
|||
MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
auto prim = this->primitive->value_as_BatchToSpace();
|
||||
auto block_shape = prim->blockShape();
|
||||
if (block_shape->size() != kBlockShapeSize) {
|
||||
|
||||
auto block_shape = GetBlockShape();
|
||||
if (block_shape.size() != kBlockShapeSize) {
|
||||
MS_LOG(ERROR) << "Block shape size should be " << kBlockShapeSize;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
auto crops = prim->crops();
|
||||
if (crops->size() != kCropsSize) {
|
||||
auto crops = GetCrops();
|
||||
if (crops.size() != kCropsSize) {
|
||||
MS_LOG(ERROR) << "Crops size should be " << kCropsSize;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
size_t mul_block_shape = 1;
|
||||
|
||||
for (size_t i = 0; i < kBlockShapeSize; ++i) {
|
||||
if (block_shape->Get(i) <= 0) {
|
||||
if (block_shape[i] <= 0) {
|
||||
MS_LOG(ERROR) << "Input block_shape should > 0!";
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
if (input_shape[kNHWC_n_index] % block_shape->Get(i)) {
|
||||
MS_LOG(ERROR) << "Dimension n " << input_shape[kNHWC_n_index] << " can not divide block_shape[" << i << "] "
|
||||
<< block_shape->Get(i);
|
||||
return RET_PARAM_INVALID;
|
||||
if (input_shape[NHWC_N] % block_shape[i]) {
|
||||
MS_LOG(ERROR) << "Dimension n " << input_shape[NHWC_N] << " can not divide block_shape[" << i << "] "
|
||||
<< block_shape[i];
|
||||
return 1;
|
||||
}
|
||||
mul_block_shape *= block_shape->Get(i);
|
||||
mul_block_shape *= block_shape[i];
|
||||
}
|
||||
|
||||
if (input_shape[kNHWC_n_index] < mul_block_shape) {
|
||||
MS_LOG(ERROR) << "Dimension n " << input_shape[kNHWC_n_index] << " < product of block shape!";
|
||||
if (input_shape[NHWC_N] < mul_block_shape) {
|
||||
MS_LOG(ERROR) << "Dimension n " << input_shape[NHWC_N] << " < product of block shape!";
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
for (size_t i = 0; i < kCropsSize; ++i) {
|
||||
if (crops->Get(i) < 0) {
|
||||
if (crops[i] < 0) {
|
||||
MS_LOG(ERROR) << "Input crops should >= 0";
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
std::vector<int32_t> output_shape(input_shape.size());
|
||||
output_shape[kNHWC_n_index] = input_shape[kNHWC_n_index] / mul_block_shape;
|
||||
output_shape[kNHWC_h_index] = input_shape[kNHWC_h_index] * block_shape->Get(0) - crops->Get(0) - crops->Get(1);
|
||||
output_shape[kNHWC_w_index] = input_shape[kNHWC_w_index] * block_shape->Get(1) - crops->Get(2) - crops->Get(3);
|
||||
output_shape[kNHWC_c_index] = input_shape[kNHWC_c_index];
|
||||
output_shape[NHWC_N] = input_shape[NHWC_N] / mul_block_shape;
|
||||
output_shape[NHWC_H] = input_shape[NHWC_H] * block_shape[0] - crops[0] - crops[1];
|
||||
output_shape[NHWC_W] = input_shape[NHWC_W] * block_shape[1] - crops[2] - crops[3];
|
||||
output_shape[NHWC_C] = input_shape[NHWC_C];
|
||||
|
||||
outputs[0]->set_shape(output_shape);
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::lite
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#include <set>
|
||||
#include <cmath>
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "mindspore/lite/c_ops/primitive_c.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "schema/inner/model_generated.h"
|
||||
#else
|
||||
|
@ -29,19 +29,18 @@
|
|||
#define LITE_MINDSPORE_LITE_C_OPS_BATCH_TO_SPACE_H_
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class BatchToSpace : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
explicit BatchToSpace(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#else
|
||||
explicit BatchToSpace(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
explicit BatchToSpace(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
std::vector<int> GetBlockShape() const;
|
||||
std::vector<int> GetCrops() const;
|
||||
void SetBlockShape(const std::vector<int> &block_shape);
|
||||
void SetCrops(const std::vector<int> &crops);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_BATCH_TO_SPACE_H_
|
|
@ -14,9 +14,10 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/bias_add.h"
|
||||
#include "src/ops/bias_add.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
std::vector<int> BiasAdd::GetAxis() const { return this->primitive->value.AsBiasAdd()->axis; }
|
||||
|
||||
|
@ -31,4 +32,5 @@ std::vector<int> BiasAdd::GetAxis() const {
|
|||
|
||||
void BiasAdd::SetAxis(const std::vector<int> &axis) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -18,7 +18,7 @@
|
|||
#include <set>
|
||||
#include <cmath>
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "mindspore/lite/c_ops/primitive_c.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "schema/inner/model_generated.h"
|
||||
#else
|
||||
|
@ -29,16 +29,15 @@
|
|||
#define LITE_MINDSPORE_LITE_C_OPS_BIAS_ADD_H_
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class BiasAdd : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
explicit BiasAdd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#else
|
||||
explicit BiasAdd(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
explicit BiasAdd(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
std::vector<int> GetAxis() const;
|
||||
void SetAxis(const std::vector<int> &axis);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_BIAS_ADD_H_
|
|
@ -14,9 +14,10 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/bias_grad.h"
|
||||
#include "src/ops/bias_grad.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
std::vector<int> BiasGrad::GetAxis() const { return this->primitive->value.AsBiasGrad()->axis; }
|
||||
|
||||
|
@ -31,4 +32,5 @@ std::vector<int> BiasGrad::GetAxis() const {
|
|||
|
||||
void BiasGrad::SetAxis(const std::vector<int> &axis) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -18,7 +18,7 @@
|
|||
#include <set>
|
||||
#include <cmath>
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "mindspore/lite/c_ops/primitive_c.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "schema/inner/model_generated.h"
|
||||
#else
|
||||
|
@ -29,16 +29,15 @@
|
|||
#define LITE_MINDSPORE_LITE_C_OPS_BIAS_GRAD_H_
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class BiasGrad : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
explicit BiasGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#else
|
||||
explicit BiasGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
explicit BiasGrad(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
std::vector<int> GetAxis() const;
|
||||
void SetAxis(const std::vector<int> &axis);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_BIAS_GRAD_H_
|
|
@ -14,9 +14,10 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/bn_grad_input.h"
|
||||
#include "src/ops/bn_grad_input.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
float BNGradInput::GetEps() const { return this->primitive->value.AsBNGradInput()->eps; }
|
||||
int BNGradInput::GetChannels() const { return this->primitive->value.AsBNGradInput()->channels; }
|
||||
|
@ -32,4 +33,5 @@ int BNGradInput::GetChannels() const { return this->primitive->value_as_BNGradIn
|
|||
void BNGradInput::SetEps(float eps) {}
|
||||
void BNGradInput::SetChannels(int channels) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -18,7 +18,7 @@
|
|||
#include <set>
|
||||
#include <cmath>
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "mindspore/lite/c_ops/primitive_c.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "schema/inner/model_generated.h"
|
||||
#else
|
||||
|
@ -29,18 +29,17 @@
|
|||
#define LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class BNGradInput : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
explicit BNGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#else
|
||||
explicit BNGradInput(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
explicit BNGradInput(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
float GetEps() const;
|
||||
int GetChannels() const;
|
||||
void SetEps(float eps);
|
||||
void SetChannels(int channels);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_
|
|
@ -14,22 +14,36 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/ops.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "src/ir/tensor.h"
|
||||
#include "src/ops/broadcast_to.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
std::vector<int> BroadcastTo::GetDstShape() const { return this->primitive->value.AsBroadcastTo()->dst_shape; }
|
||||
|
||||
void BroadcastTo::SetDstShape(const std::vector<int> &dst_shape) {
|
||||
this->primitive->value.AsBroadcastTo()->dst_shape = dst_shape;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
std::vector<int> BroadcastTo::GetDstShape() const {
|
||||
auto fb_vector = this->primitive->value_as_BroadcastTo()->dst_shape();
|
||||
return std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||
}
|
||||
|
||||
void BroadcastTo::SetDstShape(const std::vector<int> &dst_shape) {}
|
||||
#endif
|
||||
namespace {
|
||||
constexpr int kBroadcastToInputNum = 1;
|
||||
constexpr int kBroadcastToOutputNum = 1;
|
||||
} // namespace
|
||||
|
||||
int BroadcastTo::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) {
|
||||
int BroadcastTo::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
if (inputs.size() != kBroadcastToInputNum || outputs.size() != kBroadcastToOutputNum) {
|
||||
MS_LOG(ERROR) << "input size:" << inputs.size() << ", output size:" << outputs.size();
|
||||
return RET_PARAM_INVALID;
|
||||
return 1;
|
||||
}
|
||||
auto input = inputs.at(0);
|
||||
std::vector<int32_t> dst_shape(this->primitive->value_as_BroadcastTo()->dst_shape()->begin(),
|
||||
|
@ -40,19 +54,19 @@ int BroadcastTo::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<te
|
|||
if (input_shape.size() > dst_shape.size()) {
|
||||
MS_LOG(ERROR) << "input shape size " << input_shape.size() << " should <= broadcast to shape size "
|
||||
<< dst_shape.size() << "!";
|
||||
return RET_PARAM_INVALID;
|
||||
return 1;
|
||||
}
|
||||
|
||||
for (int i = dst_shape.size() - 1; i >= 0; --i) {
|
||||
if (dst_shape[i] < 0) {
|
||||
MS_LOG(ERROR) << "shape[" << i << "] = " << dst_shape[i] << " ] should be > 0!";
|
||||
return RET_PARAM_INVALID;
|
||||
return 1;
|
||||
}
|
||||
if (input_shape_index >= 0) {
|
||||
auto dim = input_shape[input_shape_index];
|
||||
if (dim != dst_shape[i] && dim != 1) {
|
||||
MS_LOG(ERROR) << "Invalid broadcast shape!";
|
||||
return RET_PARAM_INVALID;
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
shape[i] = dst_shape[i];
|
||||
|
@ -61,6 +75,7 @@ int BroadcastTo::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<te
|
|||
outputs[0]->SetFormat(input->GetFormat());
|
||||
outputs[0]->set_shape(shape);
|
||||
outputs[0]->set_data_type(input->data_type());
|
||||
return RET_OK;
|
||||
return 0;
|
||||
}
|
||||
} // namespace mindspore::lite
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#include <set>
|
||||
#include <cmath>
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "mindspore/lite/c_ops/primitive_c.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "schema/inner/model_generated.h"
|
||||
#else
|
||||
|
@ -29,17 +29,16 @@
|
|||
#define LITE_MINDSPORE_LITE_C_OPS_BROADCAST_TO_H_
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class BroadcastTo : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
explicit BroadcastTo(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#else
|
||||
explicit BroadcastTo(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
explicit BroadcastTo(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
std::vector<int> GetDstShape() const;
|
||||
void SetDstShape(const std::vector<int> &dst_shape);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_BROADCAST_TO_H_
|
|
@ -14,9 +14,10 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/caffe_p_relu.h"
|
||||
#include "src/ops/caffe_p_relu.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
bool CaffePReLU::GetChannelShared() const { return this->primitive->value.AsCaffePReLU()->channelShared; }
|
||||
|
||||
|
@ -30,4 +31,5 @@ bool CaffePReLU::GetChannelShared() const { return this->primitive->value_as_Caf
|
|||
|
||||
void CaffePReLU::SetChannelShared(bool channel_shared) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -18,8 +18,8 @@
|
|||
#include <set>
|
||||
#include <cmath>
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "mindspore/lite/c_ops/primitive_c.h"
|
||||
#include "c_ops/activation.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
#include "src/ops/activation.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "schema/inner/model_generated.h"
|
||||
#else
|
||||
|
@ -30,16 +30,15 @@
|
|||
#define LITE_MINDSPORE_LITE_C_OPS_CAFFE_P_RE_L_U_H_
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class CaffePReLU : public Activation {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
explicit CaffePReLU(schema::PrimitiveT *primitive) : Activation(primitive) {}
|
||||
#else
|
||||
explicit CaffePReLU(schema::Primitive *primitive) : Activation(primitive) {}
|
||||
#endif
|
||||
explicit CaffePReLU(OriginPrimitive *primitive) : Activation(primitive) {}
|
||||
|
||||
bool GetChannelShared() const;
|
||||
void SetChannelShared(bool channel_shared);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_CAFFE_P_RE_L_U_H_
|
|
@ -14,12 +14,26 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/ops.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "src/ir/tensor.h"
|
||||
#include "src/ops/cast.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int Cast::GetSrcT() const { return this->primitive->value.AsCast()->srcT; }
|
||||
int Cast::GetDstT() const { return this->primitive->value.AsCast()->dstT; }
|
||||
|
||||
void Cast::SetSrcT(int src_t) { this->primitive->value.AsCast()->srcT = src_t; }
|
||||
void Cast::SetDstT(int dst_t) { this->primitive->value.AsCast()->dstT = dst_t; }
|
||||
|
||||
#else
|
||||
|
||||
int Cast::GetSrcT() const { return this->primitive->value_as_Cast()->srcT(); }
|
||||
int Cast::GetDstT() const { return this->primitive->value_as_Cast()->dstT(); }
|
||||
|
||||
void Cast::SetSrcT(int src_t) {}
|
||||
void Cast::SetDstT(int dst_t) {}
|
||||
#endif
|
||||
|
||||
namespace mindspore::lite {
|
||||
int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
auto input = inputs_.front();
|
||||
|
@ -49,4 +63,5 @@ int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
|
|||
output->set_data_type(TypeId::kNumberTypeFloat32);
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::lite
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#include <set>
|
||||
#include <cmath>
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "mindspore/lite/c_ops/primitive_c.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "schema/inner/model_generated.h"
|
||||
#else
|
||||
|
@ -29,19 +29,18 @@
|
|||
#define LITE_MINDSPORE_LITE_C_OPS_CAST_H_
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class Cast : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
explicit Cast(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#else
|
||||
explicit Cast(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
explicit Cast(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetSrcT() const;
|
||||
int GetDstT() const;
|
||||
void SetSrcT(int src_t);
|
||||
void SetDstT(int dst_t);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_CAST_H_
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <cmath>
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "src/ops/arithmetic_self.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "schema/inner/model_generated.h"
|
||||
#else
|
||||
#include "schema/model_generated.h"
|
||||
#endif
|
||||
|
||||
#ifndef LITE_MINDSPORE_LITE_C_OPS_CEIL_H_
|
||||
#define LITE_MINDSPORE_LITE_C_OPS_CEIL_H_
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class Ceil : public ArithmeticSelf {
|
||||
public:
|
||||
explicit Ceil(OriginPrimitive *primitive) : ArithmeticSelf(primitive) {}
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_CEIL_H_
|
|
@ -14,9 +14,10 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/clip.h"
|
||||
#include "src/ops/clip.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
float Clip::GetMax() const { return this->primitive->value.AsClip()->max; }
|
||||
float Clip::GetMin() const { return this->primitive->value.AsClip()->min; }
|
||||
|
@ -32,4 +33,5 @@ float Clip::GetMin() const { return this->primitive->value_as_Clip()->min(); }
|
|||
void Clip::SetMax(float max) {}
|
||||
void Clip::SetMin(float min) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -18,7 +18,7 @@
|
|||
#include <set>
|
||||
#include <cmath>
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "mindspore/lite/c_ops/primitive_c.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "schema/inner/model_generated.h"
|
||||
#else
|
||||
|
@ -29,18 +29,17 @@
|
|||
#define LITE_MINDSPORE_LITE_C_OPS_CLIP_H_
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class Clip : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
explicit Clip(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#else
|
||||
explicit Clip(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
explicit Clip(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
float GetMax() const;
|
||||
float GetMin() const;
|
||||
void SetMax(float max);
|
||||
void SetMin(float min);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_CLIP_H_
|
|
@ -14,12 +14,28 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/ops.h"
|
||||
#include "src/ops/concat.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "src/ir/tensor.h"
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int Concat::GetAxis() const { return this->primitive->value.AsConcat()->axis; }
|
||||
int Concat::GetN() const { return this->primitive->value.AsConcat()->n; }
|
||||
|
||||
void Concat::SetAxis(int axis) { this->primitive->value.AsConcat()->axis = axis; }
|
||||
void Concat::SetN(int n) { this->primitive->value.AsConcat()->n = n; }
|
||||
|
||||
#else
|
||||
|
||||
int Concat::GetAxis() const { return this->primitive->value_as_Concat()->axis(); }
|
||||
int Concat::GetN() const { return this->primitive->value_as_Concat()->n(); }
|
||||
|
||||
void Concat::SetAxis(int axis) {}
|
||||
void Concat::SetN(int n) {}
|
||||
#endif
|
||||
|
||||
namespace mindspore::lite {
|
||||
namespace {
|
||||
constexpr int kConcatOutputNum = 1;
|
||||
}
|
||||
|
@ -47,7 +63,6 @@ int Concat::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
MS_LOG(ERROR) << "Invalid axis: " << axis;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
|
||||
auto input0_shape_without_axis = input0_shape;
|
||||
input0_shape_without_axis.erase(input0_shape_without_axis.begin() + axis);
|
||||
auto input0_data_type = inputs_.at(0)->data_type();
|
||||
|
@ -58,7 +73,6 @@ int Concat::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
MS_LOG(ERROR) << "All inputs should have the same data type!";
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
|
||||
if (inputs_.at(i)->GetFormat() != input0_format) {
|
||||
MS_LOG(ERROR) << "All input format should be the same!";
|
||||
return RET_PARAM_INVALID;
|
||||
|
@ -81,4 +95,5 @@ int Concat::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
outputs_[0]->set_shape(output_shape);
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::lite
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue