remove unused APIs

This commit is contained in:
zhaizhiqiang 2022-04-02 14:39:37 +08:00
parent 46a614d37a
commit e588be4291
6 changed files with 53 additions and 268 deletions

View File

@ -46,47 +46,6 @@ class MS_API Cell : public CellBase {
std::shared_ptr<CellBase> Clone() const override { return std::make_shared<T>(static_cast<const T &>(*this)); }
};
// Cell that carries a single MSTensor value (e.g. a parameter/constant input).
// Copy operations deep-copy the tensor via MSTensor::Clone(); implementations
// live in cell.cc.
class MS_API ParameterCell final : public Cell<ParameterCell> {
public:
ParameterCell() = default;
~ParameterCell() override = default;
// Deep-copying copy operations (see cell.cc).
ParameterCell(const ParameterCell &);
ParameterCell &operator=(const ParameterCell &);
ParameterCell(ParameterCell &&);
ParameterCell &operator=(ParameterCell &&);
// Construct/assign from an existing tensor.
explicit ParameterCell(const MSTensor &);
ParameterCell &operator=(const MSTensor &);
explicit ParameterCell(MSTensor &&);
ParameterCell &operator=(MSTensor &&);
// Returns a copy of the wrapped tensor handle.
MSTensor GetTensor() const { return tensor_; }
private:
MSTensor tensor_;
};
// Base class for operator cells: stores the operator's type name
// (e.g. "Conv2D") supplied at construction time.
class MS_API OpCellBase : public CellBase {
public:
explicit OpCellBase(const std::string &name) : name_(name) {}
~OpCellBase() override = default;
// Operator type name set by the constructor.
const std::string &GetOpType() const { return name_; }
protected:
std::string name_;
};
// CRTP helper for concrete op cells: Clone() copies the most-derived type T,
// and enable_shared_from_this lets an op hand out shared handles to itself
// (used by Conv2D::Construct in ops.cc).
template <class T>
class MS_API OpCell : public OpCellBase, public std::enable_shared_from_this<T> {
public:
explicit OpCell(const std::string &name) : OpCellBase(name) {}
~OpCell() override = default;
// Polymorphic copy of the most-derived object.
std::shared_ptr<CellBase> Clone() const override { return std::make_shared<T>(static_cast<const T &>(*this)); }
};
class MS_API GraphCell final : public Cell<GraphCell> {
public:
class GraphImpl;
@ -117,10 +76,6 @@ class MS_API InputAndOutput {
InputAndOutput();
~InputAndOutput() = default;
// no explicit
InputAndOutput(const MSTensor &); // NOLINT(runtime/explicit)
InputAndOutput(MSTensor &&); // NOLINT(runtime/explicit)
InputAndOutput(const std::shared_ptr<CellBase> &, const std::vector<InputAndOutput> &, int32_t index);
int32_t GetIndex() const { return index_; }

View File

@ -45,6 +45,59 @@ class MS_API Model {
Model(const Model &) = delete;
void operator=(const Model &) = delete;
/// \brief Build a model from model buffer so that it can run on a device. Only valid for Lite.
///
/// \param[in] model_data Define the buffer read from a model file.
/// \param[in] data_size Define bytes number of model buffer.
/// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only
/// ModelType::kMindIR is valid for Lite.
/// \param[in] model_context Define the context used to store options during execution.
///
/// \return Status.
Status Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context = nullptr);
/// \brief Load and build a model from a model file so that it can run on a device. Only valid for Lite.
///
/// \param[in] model_path Define the model path.
/// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only
/// ModelType::kMindIR is valid for Lite.
/// \param[in] model_context Define the context used to store options during execution.
///
/// \return Status.
Status Build(const std::string &model_path, ModelType model_type,
const std::shared_ptr<Context> &model_context = nullptr);
/// \brief Build a model from model buffer so that it can run on a device. Only valid for Lite.
///
/// \param[in] model_data Define the buffer read from a model file.
/// \param[in] data_size Define bytes number of model buffer.
/// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only
/// ModelType::kMindIR is valid for Lite.
/// \param[in] model_context Define the context used to store options during execution.
/// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16.
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM.
/// \param[in] cropto_lib_path Define the openssl library path.
///
/// \return Status.
Status Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode,
const std::string &cropto_lib_path);
/// \brief Load and build a model from a model file so that it can run on a device. Only valid for Lite.
///
/// \param[in] model_path Define the model path.
/// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only
/// ModelType::kMindIR is valid for Lite.
/// \param[in] model_context Define the context used to store options during execution.
/// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16.
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM.
/// \param[in] cropto_lib_path Define the openssl library path.
///
/// \return Status.
Status Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
const Key &dec_key, const std::string &dec_mode, const std::string &cropto_lib_path);
/// \brief Builds a model
///
/// \param[in] graph GraphCell is a derivative of Cell. Cell is not available currently. GraphCell can be constructed
@ -253,59 +306,6 @@ class MS_API Model {
Status Train(int epochs, std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> cbs);
Status Evaluate(std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> cbs);
/// \brief Build a model from model buffer so that it can run on a device. Only valid for Lite.
///
/// \param[in] model_data Define the buffer read from a model file.
/// \param[in] data_size Define bytes number of model buffer.
/// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only
/// ModelType::kMindIR is valid for Lite.
/// \param[in] model_context Define the context used to store options during execution.
///
/// \return Status.
Status Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context = nullptr);
/// \brief Load and build a model from a model file so that it can run on a device. Only valid for Lite.
///
/// \param[in] model_path Define the model path.
/// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only
/// ModelType::kMindIR is valid for Lite.
/// \param[in] model_context Define the context used to store options during execution.
///
/// \return Status.
Status Build(const std::string &model_path, ModelType model_type,
const std::shared_ptr<Context> &model_context = nullptr);
/// \brief Build a model from model buffer so that it can run on a device. Only valid for Lite.
///
/// \param[in] model_data Define the buffer read from a model file.
/// \param[in] data_size Define bytes number of model buffer.
/// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only
/// ModelType::kMindIR is valid for Lite.
/// \param[in] model_context Define the context used to store options during execution.
/// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16.
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM.
/// \param[in] cropto_lib_path Define the openssl library path.
///
/// \return Status.
Status Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode,
const std::string &cropto_lib_path);
/// \brief Load and build a model from a model file so that it can run on a device. Only valid for Lite.
///
/// \param[in] model_path Define the model path.
/// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only
/// ModelType::kMindIR is valid for Lite.
/// \param[in] model_context Define the context used to store options during execution.
/// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16.
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM.
/// \param[in] cropto_lib_path Define the openssl library path.
///
/// \return Status.
Status Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
const Key &dec_key, const std::string &dec_mode, const std::string &cropto_lib_path);
private:
friend class Serialization;
// api without std::string

View File

@ -1,48 +0,0 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INCLUDE_API_OPS_OPS_H
#define MINDSPORE_INCLUDE_API_OPS_OPS_H
#include <string>
#include <vector>
#include <map>
#include <memory>
#include "include/api/status.h"
#include "include/api/types.h"
#include "include/api/cell.h"
namespace mindspore {
// 2-D convolution op cell. The public fields mirror the constructor's
// parameters (and their defaults); see ops.cc for the implementation.
struct MS_API Conv2D : public OpCell<Conv2D> {
Conv2D() : OpCell("Conv2D") {}
~Conv2D() override = default;
// Emits this op's symbolic output(s) for the given inputs.
std::vector<Output> Construct(const std::vector<Input> &inputs) override;
Conv2D(int out_channel, const std::vector<int> &kernel_size, int mode = 1, const std::string &pad_mode = "valid",
const std::vector<int> &pad = {0, 0, 0, 0}, const std::vector<int> &stride = {1, 1, 1, 1},
const std::vector<int> &dilation = {1, 1, 1, 1}, int group = 1);
// Convenience call form taking (input, weight) and returning the single output.
Output operator()(const Input &, const Input &) const;
int out_channel = 0;
std::vector<int> kernel_size;
int mode = 1;
std::string pad_mode = "valid";
std::vector<int> pad = {0, 0, 0, 0};
std::vector<int> stride = {1, 1, 1, 1};
std::vector<int> dilation = {1, 1, 1, 1};
int group = 1;
};
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_OPS_OPS_H

View File

@ -21,52 +21,6 @@
namespace mindspore {
// Invoking a cell clones it and runs Construct on the clone, so the stored
// prototype cell itself is never mutated by a call.
std::vector<Output> CellBase::operator()(const std::vector<Input> &inputs) const {
  auto copy = Clone();
  return copy->Construct(inputs);
}
// Copy constructor: deep-copies the source cell's tensor via Clone() and
// releases the temporary clone handle afterwards.
ParameterCell::ParameterCell(const ParameterCell &rhs) : Cell<ParameterCell>(rhs) {
  auto cloned = rhs.tensor_.Clone();
  tensor_ = *cloned;
  MSTensor::DestroyTensorPtr(cloned);
}
// Copy assignment: deep-copies the source tensor; self-assignment is a no-op.
ParameterCell &ParameterCell::operator=(const ParameterCell &rhs) {
  if (this != &rhs) {
    auto cloned = rhs.tensor_.Clone();
    tensor_ = *cloned;
    MSTensor::DestroyTensorPtr(cloned);
  }
  return *this;
}
// Move constructor. NOTE(review): tensor_ is copy-constructed from cell.tensor_
// rather than moved — presumably MSTensor is a cheap handle type; confirm.
ParameterCell::ParameterCell(ParameterCell &&cell) : Cell<ParameterCell>(std::move(cell)), tensor_(cell.tensor_) {}
// Move assignment; self-assignment is a no-op. As in the move constructor,
// the tensor handle is copied from the source rather than moved.
ParameterCell &ParameterCell::operator=(ParameterCell &&rhs) {
  if (this != &rhs) {
    tensor_ = rhs.tensor_;
  }
  return *this;
}
// Wrap a deep copy of an existing tensor.
ParameterCell::ParameterCell(const MSTensor &src) {
  auto cloned = src.Clone();
  tensor_ = *cloned;
  MSTensor::DestroyTensorPtr(cloned);
}
// Replace the held tensor with a deep copy of the argument.
ParameterCell &ParameterCell::operator=(const MSTensor &src) {
  auto cloned = src.Clone();
  tensor_ = *cloned;
  MSTensor::DestroyTensorPtr(cloned);
  return *this;
}
// Rvalue-tensor constructor. NOTE(review): the tensor is copied, not moved —
// presumably MSTensor is a cheap handle; confirm before relying on move semantics.
ParameterCell::ParameterCell(MSTensor &&tensor) : tensor_(tensor) {}
// Rvalue-tensor assignment. NOTE(review): copies the tensor handle
// (no actual move) — mirrors the rvalue constructor above; confirm.
ParameterCell &ParameterCell::operator=(MSTensor &&tensor) {
tensor_ = tensor;
return *this;
}
// GraphCell wraps a Graph. Both constructors validate that the stored
// shared_ptr is non-null (only the shared_ptr overload can actually receive null;
// make_shared reports allocation failure by throwing).
GraphCell::GraphCell(const Graph &graph) : graph_(std::make_shared<Graph>(graph)) { MS_EXCEPTION_IF_NULL(graph_); }
GraphCell::GraphCell(const std::shared_ptr<Graph> &graph) : graph_(graph) { MS_EXCEPTION_IF_NULL(graph_); }
@ -135,14 +89,6 @@ std::vector<MSTensor> GraphCell::GetOutputs() {
// Default: not bound to any cell; index -1 presumably marks "unbound" — confirm.
InputAndOutput::InputAndOutput() : cell_(nullptr), prev_(), index_(-1) {}
// Wrap a tensor as a ParameterCell input; the tensor is deep-copied and the
// temporary clone handle is released afterwards.
InputAndOutput::InputAndOutput(const MSTensor &tensor) : prev_(), index_(-1) {
  auto cloned = tensor.Clone();
  cell_ = std::make_shared<ParameterCell>(*cloned);
  MSTensor::DestroyTensorPtr(cloned);
}
// Rvalue overload. NOTE(review): `tensor` is a named rvalue reference, i.e. an
// lvalue here, so ParameterCell's const& (deep-copy) constructor is selected
// rather than its && overload — confirm whether std::move(tensor) was intended.
InputAndOutput::InputAndOutput(MSTensor &&tensor)
: cell_(std::make_shared<ParameterCell>(tensor)), prev_(), index_(-1) {}
// Bind to the `index`-th output of `cell`, remembering the producing inputs.
InputAndOutput::InputAndOutput(const std::shared_ptr<CellBase> &cell, const std::vector<InputAndOutput> &prev,
int32_t index)
: cell_(cell), prev_(prev), index_(index) {}

View File

@ -1,38 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "include/api/ops/ops.h"
namespace mindspore {
// Constructs a Conv2D op cell: forwards the op-type name "Conv2D" to OpCell
// and stores every hyper-parameter verbatim in the matching public field.
Conv2D::Conv2D(int out_channel, const std::vector<int> &kernel_size, int mode, const std::string &pad_mode,
const std::vector<int> &pad, const std::vector<int> &stride, const std::vector<int> &dilation, int group)
: OpCell("Conv2D"),
out_channel(out_channel),
kernel_size(kernel_size),
mode(mode),
pad_mode(pad_mode),
pad(pad),
stride(stride),
dilation(dilation),
group(group) {}
// Binary call form: route both inputs through the generic cell-call protocol
// (CellBase::operator()) and return its single output.
Output Conv2D::operator()(const Input &lhs, const Input &rhs) const {
  const std::vector<Output> outs = CellBase::operator()({lhs, rhs});
  return outs[0];
}
// Produce one symbolic Output tied to this op instance via shared_from_this().
// NOTE(review): the trailing literal is forwarded to Output's
// (cell, prev, index) constructor; whether 1 means "second output" or
// "output count" is not visible here — confirm against the Output ctor.
std::vector<Output> Conv2D::Construct(const std::vector<Input> &inputs) {
return {Output(shared_from_this(), inputs, 1)};
}
} // namespace mindspore

View File

@ -24,33 +24,6 @@ std::vector<Output> CellBase::operator()(const std::vector<Input> &inputs) const
return empty;
}
// Lite build: ParameterCell is not supported. Every special member logs an
// error and leaves the object in its default state.
ParameterCell::ParameterCell(const ParameterCell &cell) { MS_LOG(ERROR) << "Unsupported feature."; }
ParameterCell &ParameterCell::operator=(const ParameterCell &cell) {
MS_LOG(ERROR) << "Unsupported feature.";
return *this;
}
ParameterCell::ParameterCell(ParameterCell &&cell) { MS_LOG(ERROR) << "Unsupported feature."; }
ParameterCell &ParameterCell::operator=(ParameterCell &&cell) {
MS_LOG(ERROR) << "Unsupported feature.";
return *this;
}
ParameterCell::ParameterCell(const MSTensor &tensor) { MS_LOG(ERROR) << "Unsupported feature."; }
ParameterCell &ParameterCell::operator=(const MSTensor &tensor) {
MS_LOG(ERROR) << "Unsupported feature.";
return *this;
}
// NOTE(review): unlike the other stubs, this one still stores the tensor
// before logging — confirm whether that is intentional.
ParameterCell::ParameterCell(MSTensor &&tensor) : tensor_(tensor) { MS_LOG(ERROR) << "Unsupported feature."; }
ParameterCell &ParameterCell::operator=(MSTensor &&tensor) {
MS_LOG(ERROR) << "Unsupported feature.";
return *this;
}
GraphCell::GraphCell(const Graph &graph) : graph_(std::shared_ptr<Graph>(new (std::nothrow) Graph(graph))) {
if (graph_ == nullptr) {
MS_LOG(ERROR) << "Invalid graph.";
@ -81,9 +54,6 @@ Status GraphCell::Load(uint32_t device_id) {
// Lite build: InputAndOutput construction is unsupported; stubs only log.
InputAndOutput::InputAndOutput() { MS_LOG(ERROR) << "Unsupported feature."; }
InputAndOutput::InputAndOutput(const MSTensor &tensor) { MS_LOG(ERROR) << "Unsupported feature."; }
InputAndOutput::InputAndOutput(MSTensor &&tensor) { MS_LOG(ERROR) << "Unsupported feature."; }
InputAndOutput::InputAndOutput(const std::shared_ptr<CellBase> &cell, const std::vector<InputAndOutput> &prev,
int32_t index) {
MS_LOG(ERROR) << "Unsupported feature.";