forked from mindspore-Ecosystem/mindspore
parent
7835f73fea
commit
d259c159ba
|
@ -1,23 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "ops/add_fold.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
REGISTER_PRIMITIVE_C(kNameAddFold, AddFold);
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,40 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_CORE_OPS_ADD_FOLD_H_
|
|
||||||
#define MINDSPORE_CORE_OPS_ADD_FOLD_H_
|
|
||||||
#include <map>
|
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
|
||||||
#include <memory>
|
|
||||||
#include "ops/primitive_c.h"
|
|
||||||
#include "abstract/abstract_value.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
constexpr auto kNameAddFold = "AddFold";
|
|
||||||
class AddFold : public PrimitiveC {
|
|
||||||
public:
|
|
||||||
AddFold() : PrimitiveC(kNameAddFold) {}
|
|
||||||
~AddFold() = default;
|
|
||||||
MS_DECLARE_PARENT(AddFold, PrimitiveC);
|
|
||||||
void Init() {}
|
|
||||||
};
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
||||||
|
|
||||||
#endif // MINDSPORE_CORE_OPS_ADD_FOLD_H_
|
|
|
@ -1,94 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
#include <map>
|
|
||||||
#include <string>
|
|
||||||
#include "ops/batch_norm_fold.h"
|
|
||||||
#include "ops/op_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
void BatchNormFold::Init(const float momentum, const float epsilon, const bool is_training, const int64_t freeze_bn) {
|
|
||||||
set_momentum(momentum);
|
|
||||||
set_epsilon(epsilon);
|
|
||||||
set_is_training(is_training);
|
|
||||||
set_freeze_bn(freeze_bn);
|
|
||||||
}
|
|
||||||
|
|
||||||
void BatchNormFold::set_momentum(const float momentum) {
|
|
||||||
CheckAndConvertUtils::CheckInRange<int64_t>(kMomentum, momentum, kIncludeBoth, {0.0, 1.0}, this->name());
|
|
||||||
this->AddAttr(kMomentum, MakeValue(momentum));
|
|
||||||
}
|
|
||||||
|
|
||||||
float BatchNormFold::get_momentum() const {
|
|
||||||
auto value_ptr = GetAttr(kMomentum);
|
|
||||||
return GetValue<float>(value_ptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
void BatchNormFold::set_epsilon(const float epsilon) {
|
|
||||||
float match_value = 0.0;
|
|
||||||
CheckAndConvertUtils::CheckValue(kEpsilon, epsilon, kGreaterThan, match_value, this->name());
|
|
||||||
this->AddAttr(kEpsilon, MakeValue(epsilon));
|
|
||||||
}
|
|
||||||
|
|
||||||
float BatchNormFold::get_epsilon() const {
|
|
||||||
auto value_ptr = GetAttr(kEpsilon);
|
|
||||||
return GetValue<float>(value_ptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
void BatchNormFold::set_is_training(const bool is_training) { this->AddAttr(kIsTraining, MakeValue(is_training)); }
|
|
||||||
|
|
||||||
bool BatchNormFold::get_is_training() const {
|
|
||||||
auto value_ptr = GetAttr(kIsTraining);
|
|
||||||
return GetValue<bool>(value_ptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
void BatchNormFold::set_freeze_bn(const int64_t freeze_bn) { this->AddAttr(kFreezeBn, MakeValue(freeze_bn)); }
|
|
||||||
|
|
||||||
int64_t BatchNormFold::get_freeze_bn() const {
|
|
||||||
auto value_ptr = GetAttr(kFreezeBn);
|
|
||||||
return GetValue<int64_t>(value_ptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
AbstractBasePtr BatchNormFoldInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
|
||||||
auto op_name = primitive->name();
|
|
||||||
auto mean_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
|
||||||
auto variance_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
|
||||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
|
||||||
auto global_step_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape];
|
|
||||||
CheckAndConvertUtils::Check("mean_shape", mean_shape, kEqual, "gamma_shape", variance_shape, op_name);
|
|
||||||
CheckAndConvertUtils::Check("mean_shape[0]", mean_shape[0], kEqual, "input channel", x_shape[1], op_name);
|
|
||||||
CheckAndConvertUtils::CheckInteger("global step shape len", global_step_shape.size(), kEqual, 1, op_name);
|
|
||||||
|
|
||||||
auto mean_type = input_args[1]->BuildType();
|
|
||||||
auto variance_type = input_args[2]->BuildType();
|
|
||||||
auto x_type = input_args[0]->BuildType();
|
|
||||||
auto global_step_type = input_args[3]->BuildType();
|
|
||||||
|
|
||||||
std::map<std::string, TypePtr> args = {{"x", x_type}, {"mean", mean_type}, {"variance", variance_type}};
|
|
||||||
auto element0 = CheckAndConvertUtils::CheckTensorTypeSame(args, {kFloat16, kFloat32}, op_name);
|
|
||||||
CheckAndConvertUtils::CheckTensorTypeValid("gloabal_step", global_step_type, {kInt32}, op_name);
|
|
||||||
|
|
||||||
auto output = std::make_shared<abstract::AbstractTensor>(element0, mean_shape);
|
|
||||||
AbstractBasePtrList output1 = {output, output, output, output};
|
|
||||||
return std::make_shared<abstract::AbstractTuple>(output1);
|
|
||||||
}
|
|
||||||
REGISTER_PRIMITIVE_C(kNameBatchNormFold, BatchNormFold);
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,54 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#ifndef MINDSPORE_CORE_OPS_BATCH_NORM_FOLD_H_
|
|
||||||
#define MINDSPORE_CORE_OPS_BATCH_NORM_FOLD_H_
|
|
||||||
#include <memory>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "ops/primitive_c.h"
|
|
||||||
#include "abstract/abstract_value.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
constexpr auto kNameBatchNormFold = "BatchNormFold";
|
|
||||||
class BatchNormFold : public PrimitiveC {
|
|
||||||
public:
|
|
||||||
BatchNormFold() : PrimitiveC(kNameBatchNormFold) {
|
|
||||||
InitIOName({"x", "mean", "variance", "global_step"}, {"batch_mean", "batch_std", "running_mean", "running_std"});
|
|
||||||
}
|
|
||||||
~BatchNormFold() = default;
|
|
||||||
MS_DECLARE_PARENT(BatchNormFold, PrimitiveC);
|
|
||||||
void Init(const float momentum = 0.9, const float epsilon = 1e-5, const bool is_training = true,
|
|
||||||
const int64_t freeze_bn = 0);
|
|
||||||
void set_momentum(const float momentum);
|
|
||||||
void set_epsilon(const float epsilon);
|
|
||||||
void set_is_training(const bool is_training);
|
|
||||||
void set_freeze_bn(const int64_t freeze_bn);
|
|
||||||
|
|
||||||
float get_momentum() const;
|
|
||||||
float get_epsilon() const;
|
|
||||||
bool get_is_training() const;
|
|
||||||
int64_t get_freeze_bn() const;
|
|
||||||
};
|
|
||||||
|
|
||||||
AbstractBasePtr BatchNormFoldInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
|
||||||
using PrimBatchNormFoldPtr = std::shared_ptr<BatchNormFold>;
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
||||||
|
|
||||||
#endif // MINDSPORE_CORE_OPS_BATCH_NORM_FOLD_H_
|
|
|
@ -1,52 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "ops/black_box.h"
|
|
||||||
#include "ops/op_utils.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
#include "abstract/primitive_infer_map.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
void BlackBox::Init(const std::string &id, const int64_t size, const std::vector<int64_t> &address) {
|
|
||||||
this->set_id(id);
|
|
||||||
this->set_size(size);
|
|
||||||
this->set_address(address);
|
|
||||||
}
|
|
||||||
|
|
||||||
void BlackBox::set_id(const std::string &id) { this->AddAttr(kId, MakeValue(id)); }
|
|
||||||
|
|
||||||
std::string BlackBox::get_id() const {
|
|
||||||
auto value_ptr = this->GetAttr(kId);
|
|
||||||
return GetValue<std::string>(value_ptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
void BlackBox::set_size(const int64_t size) { this->AddAttr(kSize, MakeValue(size)); }
|
|
||||||
|
|
||||||
int64_t BlackBox::get_size() const {
|
|
||||||
auto value_ptr = this->GetAttr(kSize);
|
|
||||||
return GetValue<int64_t>(value_ptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
void BlackBox::set_address(const std::vector<int64_t> &address) { this->AddAttr(kAddress, MakeValue(address)); }
|
|
||||||
|
|
||||||
std::vector<int64_t> BlackBox::get_address() const {
|
|
||||||
auto value_ptr = this->GetAttr(kAddress);
|
|
||||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
|
||||||
}
|
|
||||||
REGISTER_PRIMITIVE_C(kNameBlackBox, BlackBox);
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,47 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_CORE_OPS_BLACK_BOX_H_
|
|
||||||
#define MINDSPORE_CORE_OPS_BLACK_BOX_H_
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
#include <memory>
|
|
||||||
#include "ops/primitive_c.h"
|
|
||||||
#include "abstract/abstract_value.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
constexpr auto kNameBlackBox = "BlackBox";
|
|
||||||
class BlackBox : public PrimitiveC {
|
|
||||||
public:
|
|
||||||
BlackBox() : PrimitiveC(kNameBlackBox) {}
|
|
||||||
~BlackBox() = default;
|
|
||||||
MS_DECLARE_PARENT(BlackBox, PrimitiveC);
|
|
||||||
void Init(const std::string &id, const int64_t size, const std::vector<int64_t> &address);
|
|
||||||
void set_id(const std::string &id);
|
|
||||||
void set_size(const int64_t size);
|
|
||||||
void set_address(const std::vector<int64_t> &address);
|
|
||||||
std::string get_id() const;
|
|
||||||
int64_t get_size() const;
|
|
||||||
std::vector<int64_t> get_address() const;
|
|
||||||
};
|
|
||||||
|
|
||||||
using PrimBlackBoxPtr = std::shared_ptr<BlackBox>;
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
||||||
|
|
||||||
#endif // MINDSPORE_CORE_OPS_BLACK_BOX_H_
|
|
|
@ -1,56 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include <map>
|
|
||||||
#include <string>
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
#include "ops/constant.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
#include "abstract/primitive_infer_map.h"
|
|
||||||
#include "ops/op_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
namespace {
|
|
||||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
|
||||||
auto x = input_args[0]->BuildShape();
|
|
||||||
auto shape_element = x->cast<abstract::ShapePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(shape_element);
|
|
||||||
return shape_element;
|
|
||||||
}
|
|
||||||
|
|
||||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
|
||||||
MS_EXCEPTION_IF_NULL(prim);
|
|
||||||
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim->name());
|
|
||||||
for (const auto &item : input_args) {
|
|
||||||
MS_EXCEPTION_IF_NULL(item);
|
|
||||||
}
|
|
||||||
std::map<std::string, TypePtr> types;
|
|
||||||
types.emplace("x", input_args[0]->BuildType());
|
|
||||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
|
||||||
}
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
AbstractBasePtr ConstantInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
|
||||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
|
||||||
InferShape(primitive, input_args)->shape());
|
|
||||||
}
|
|
||||||
REGISTER_PRIMITIVE_C(kNameConstant, Constant);
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,42 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_CORE_OPS_CONSTANT_H_
|
|
||||||
#define MINDSPORE_CORE_OPS_CONSTANT_H_
|
|
||||||
#include <vector>
|
|
||||||
#include <memory>
|
|
||||||
#include "ops/primitive_c.h"
|
|
||||||
#include "abstract/abstract_value.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
constexpr auto kNameConstant = "Constant";
|
|
||||||
class Constant : public PrimitiveC {
|
|
||||||
public:
|
|
||||||
Constant() : PrimitiveC(kNameConstant) {}
|
|
||||||
~Constant() = default;
|
|
||||||
MS_DECLARE_PARENT(Constant, PrimitiveC);
|
|
||||||
void Init() {}
|
|
||||||
};
|
|
||||||
|
|
||||||
AbstractBasePtr ConstantInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
|
||||||
using PrimConstantPtr = std::shared_ptr<Constant>;
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
||||||
|
|
||||||
#endif // MINDSPORE_CORE_OPS_CONSTANT_H_
|
|
|
@ -15,7 +15,6 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <vector>
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "ops/crop_and_resize.h"
|
#include "ops/crop_and_resize.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
|
@ -23,17 +22,17 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
void CropAndResize::Init(const ResizeMethod method, const float extrapolation_value) {
|
void CropAndResize::Init(ResizeMethod method, float extrapolation_value) {
|
||||||
this->set_method(method);
|
this->set_method(method);
|
||||||
this->set_extrapolation_value(extrapolation_value);
|
this->set_extrapolation_value(extrapolation_value);
|
||||||
}
|
}
|
||||||
|
|
||||||
void CropAndResize::set_method(const ResizeMethod method) {
|
void CropAndResize::set_method(ResizeMethod method) {
|
||||||
auto swi = (int64_t)method;
|
auto swi = (int64_t)method;
|
||||||
this->AddAttr(kMethod, MakeValue(swi));
|
this->AddAttr(kMethod, MakeValue(swi));
|
||||||
}
|
}
|
||||||
|
|
||||||
void CropAndResize::set_extrapolation_value(const float extrapolation_value) {
|
void CropAndResize::set_extrapolation_value(float extrapolation_value) {
|
||||||
this->AddAttr(kExtrapolationValue, MakeValue(extrapolation_value));
|
this->AddAttr(kExtrapolationValue, MakeValue(extrapolation_value));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -30,17 +30,13 @@ class CropAndResize : public PrimitiveC {
|
||||||
CropAndResize() : PrimitiveC(kNameCropAndResize) { InitIOName({"x", "boxes", "box_index", "crop_size"}, {"y"}); }
|
CropAndResize() : PrimitiveC(kNameCropAndResize) { InitIOName({"x", "boxes", "box_index", "crop_size"}, {"y"}); }
|
||||||
~CropAndResize() = default;
|
~CropAndResize() = default;
|
||||||
MS_DECLARE_PARENT(CropAndResize, PrimitiveC);
|
MS_DECLARE_PARENT(CropAndResize, PrimitiveC);
|
||||||
void Init(const ResizeMethod method, const float extrapolation_value);
|
void Init(ResizeMethod method, float extrapolation_value);
|
||||||
|
|
||||||
void set_method(const ResizeMethod method);
|
|
||||||
void set_extrapolation_value(const float extrapolation_value);
|
|
||||||
|
|
||||||
|
void set_method(ResizeMethod method);
|
||||||
|
void set_extrapolation_value(float extrapolation_value);
|
||||||
ResizeMethod get_method() const;
|
ResizeMethod get_method() const;
|
||||||
float get_extrapolation_value() const;
|
float get_extrapolation_value() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
AbstractBasePtr CropAndResizeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
|
||||||
using PrimCropAndResizePtr = std::shared_ptr<CropAndResize>;
|
using PrimCropAndResizePtr = std::shared_ptr<CropAndResize>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,215 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "ops/depthwise_conv2d.h"
|
|
||||||
#include <string>
|
|
||||||
#include <memory>
|
|
||||||
#include <vector>
|
|
||||||
#include <algorithm>
|
|
||||||
#include "ops/op_utils.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
#include "abstract/primitive_infer_map.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
void DepthWiseConv2D::Init(const int64_t channel_multiplier, const std::vector<int64_t> &kernel_size,
|
|
||||||
const int64_t mode, const PadMode &pad_mode, const std::vector<int64_t> &pad,
|
|
||||||
const std::vector<int64_t> &stride, const std::vector<int64_t> &dilation,
|
|
||||||
const int64_t group) {
|
|
||||||
auto prim_name = this->name();
|
|
||||||
this->set_format(NCHW);
|
|
||||||
this->AddAttr("offset_a", MakeValue(0));
|
|
||||||
this->set_mode(CheckAndConvertUtils::CheckInteger("mode", mode, kEqual, 3, prim_name));
|
|
||||||
|
|
||||||
this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, prim_name));
|
|
||||||
auto strides = CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name());
|
|
||||||
if (strides[0] != strides[1]) {
|
|
||||||
MS_EXCEPTION(ValueError) << "The height and width of stride should be equal, but got height " << strides[0]
|
|
||||||
<< ", width " << strides[1];
|
|
||||||
}
|
|
||||||
this->set_stride(strides);
|
|
||||||
auto dilations = CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name());
|
|
||||||
if (dilations[0] != dilations[1]) {
|
|
||||||
MS_EXCEPTION(ValueError) << "The height and width of dilation should be equal, but got height " << dilations[0]
|
|
||||||
<< ", width " << dilations[1];
|
|
||||||
}
|
|
||||||
this->set_dilation(dilations);
|
|
||||||
this->set_pad_mode(pad_mode);
|
|
||||||
|
|
||||||
CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, prim_name);
|
|
||||||
if (pad_mode == PAD) {
|
|
||||||
for (auto item : pad) {
|
|
||||||
CheckAndConvertUtils::Check("pad_item", item, kGreaterEqual, "zeros_list", 0, prim_name);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, prim_name);
|
|
||||||
}
|
|
||||||
this->set_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name()));
|
|
||||||
|
|
||||||
this->set_out_channel(
|
|
||||||
CheckAndConvertUtils::CheckInteger("channel_multiplier", channel_multiplier, kGreaterThan, 0, prim_name));
|
|
||||||
this->set_group(CheckAndConvertUtils::CheckInteger("group", group, kGreaterThan, 0, prim_name));
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<int64_t> DepthWiseConv2D::get_kernel_size() const {
|
|
||||||
return GetValue<std::vector<int64_t>>(GetAttr(kKernelSize));
|
|
||||||
}
|
|
||||||
std::vector<int64_t> DepthWiseConv2D::get_stride() const { return GetValue<std::vector<int64_t>>(GetAttr(kStride)); }
|
|
||||||
std::vector<int64_t> DepthWiseConv2D::get_dilation() const {
|
|
||||||
return GetValue<std::vector<int64_t>>(GetAttr(kDilation));
|
|
||||||
}
|
|
||||||
PadMode DepthWiseConv2D::get_pad_mode() const { return PadMode(GetValue<int64_t>(GetAttr(kPadMode))); }
|
|
||||||
std::vector<int64_t> DepthWiseConv2D::get_pad() const { return GetValue<std::vector<int64_t>>(GetAttr(kPad)); }
|
|
||||||
|
|
||||||
std::vector<int64_t> DepthWiseConv2D::get_pads() const {
|
|
||||||
auto value_ptr = this->GetAttr(kPads);
|
|
||||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t DepthWiseConv2D::get_mode() const {
|
|
||||||
auto value_ptr = this->GetAttr(kMode);
|
|
||||||
return GetValue<int64_t>(value_ptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t DepthWiseConv2D::get_group() const {
|
|
||||||
auto value_ptr = this->GetAttr(kGroup);
|
|
||||||
return GetValue<int64_t>(value_ptr);
|
|
||||||
}
|
|
||||||
int64_t DepthWiseConv2D::get_out_channel() const { return GetValue<int64_t>(GetAttr(kOutChannel)); }
|
|
||||||
|
|
||||||
void DepthWiseConv2D::set_kernel_size(const std::vector<int64_t> &kernel_size) {
|
|
||||||
this->AddAttr(kKernelSize, MakeValue(kernel_size));
|
|
||||||
}
|
|
||||||
|
|
||||||
void DepthWiseConv2D::set_stride(const std::vector<int64_t> &stride) { this->AddAttr(kStride, MakeValue(stride)); }
|
|
||||||
void DepthWiseConv2D::set_dilation(const std::vector<int64_t> &dilation) {
|
|
||||||
this->AddAttr(kDilation, MakeValue(dilation));
|
|
||||||
}
|
|
||||||
void DepthWiseConv2D::set_pad_mode(const PadMode &pad_mode) {
|
|
||||||
int64_t swi = pad_mode;
|
|
||||||
this->AddAttr(kPadMode, MakeValue(swi));
|
|
||||||
}
|
|
||||||
void DepthWiseConv2D::set_pad(const std::vector<int64_t> &pad) { this->AddAttr(kPad, MakeValue(pad)); }
|
|
||||||
void DepthWiseConv2D::set_mode(const int64_t mode) { this->AddAttr(kMode, MakeValue(mode)); }
|
|
||||||
void DepthWiseConv2D::set_group(const int64_t group) { this->AddAttr(kGroup, MakeValue(group)); }
|
|
||||||
void DepthWiseConv2D::set_out_channel(const int64_t out_channel) { this->AddAttr(kOutChannel, MakeValue(out_channel)); }
|
|
||||||
void DepthWiseConv2D::set_pads(const std::vector<int64_t> &pad_list) { this->AddAttr(kPads, MakeValue(pad_list)); }
|
|
||||||
void DepthWiseConv2D::set_format(const Format &format) {
|
|
||||||
int64_t f = format;
|
|
||||||
this->AddAttr(kFormat, MakeValue(f));
|
|
||||||
}
|
|
||||||
|
|
||||||
Format DepthWiseConv2D::get_format() const { return Format(GetValue<int64_t>(GetAttr(kFormat))); }
|
|
||||||
|
|
||||||
abstract::ShapePtr DepthWiseConv2DInferShape(const PrimitivePtr &primitive,
|
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
|
||||||
auto prim_name = primitive->name();
|
|
||||||
CheckAndConvertUtils::CheckInRange<size_t>("conv2d_Infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
|
|
||||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
|
||||||
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack())[kShape];
|
|
||||||
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
|
|
||||||
if (format == NHWC) {
|
|
||||||
x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};
|
|
||||||
w_shape = {w_shape[0], w_shape[3], w_shape[1], w_shape[2]};
|
|
||||||
}
|
|
||||||
CheckAndConvertUtils::CheckInteger("weight_rank", w_shape.size(), kEqual, 4, prim_name);
|
|
||||||
CheckAndConvertUtils::CheckInteger("x_rank", x_shape.size(), kEqual, 4, prim_name);
|
|
||||||
CheckAndConvertUtils::Check("x_shape[1]", x_shape[1], kEqual, "w_shape[1]", w_shape[1], prim_name);
|
|
||||||
auto out_channel = GetValue<int64_t>(primitive->GetAttr(kOutChannel));
|
|
||||||
|
|
||||||
std::vector<int64_t> temp_w;
|
|
||||||
std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w));
|
|
||||||
CheckAndConvertUtils::Check("kernel_size", GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)), kEqual,
|
|
||||||
"w_shape[2:4]", temp_w, prim_name);
|
|
||||||
|
|
||||||
auto kernel_size_n = w_shape[0];
|
|
||||||
if (kernel_size_n != 1) {
|
|
||||||
MS_EXCEPTION(ValueError) << "The batch of input weeight should be 1, but got " << kernel_size_n;
|
|
||||||
}
|
|
||||||
auto kernel_size_h = w_shape[2];
|
|
||||||
auto kernel_size_w = w_shape[3];
|
|
||||||
auto stride = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride));
|
|
||||||
auto dilation = GetValue<std::vector<int64_t>>(primitive->GetAttr(kDilation));
|
|
||||||
auto stride_h = stride[2];
|
|
||||||
auto stride_w = stride[3];
|
|
||||||
auto dilation_h = dilation[2];
|
|
||||||
auto dilation_w = dilation[3];
|
|
||||||
int64_t h_out = -1;
|
|
||||||
int64_t w_out = -1;
|
|
||||||
std::vector<int64_t> pad_list(4, 0);
|
|
||||||
auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode)));
|
|
||||||
if (pad_mode == VALID) {
|
|
||||||
h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h);
|
|
||||||
w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w);
|
|
||||||
} else if (pad_mode == SAME) {
|
|
||||||
h_out = ceil(x_shape[2] / stride_h);
|
|
||||||
w_out = ceil(x_shape[3] / stride_w);
|
|
||||||
|
|
||||||
auto pad_needed_h =
|
|
||||||
std::max(static_cast<int64_t>(0), (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2]);
|
|
||||||
pad_list.emplace_back(floor(pad_needed_h / 2));
|
|
||||||
pad_list.emplace_back(pad_needed_h / 2);
|
|
||||||
auto pad_needed_w =
|
|
||||||
std::max(static_cast<int64_t>(0), (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3]);
|
|
||||||
auto pad_left = floor(pad_needed_w / 2);
|
|
||||||
pad_list.emplace_back(pad_left);
|
|
||||||
pad_list.emplace_back(pad_needed_h - pad_left);
|
|
||||||
} else if (pad_mode == PAD) {
|
|
||||||
auto pads = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPad));
|
|
||||||
std::copy(pads.begin(), pads.end(), std::back_inserter(pad_list));
|
|
||||||
auto pad_top = pads[0];
|
|
||||||
auto pad_bottom = pads[1];
|
|
||||||
auto pad_right = pads[2];
|
|
||||||
auto pad_left = pads[3];
|
|
||||||
|
|
||||||
h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h;
|
|
||||||
w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w;
|
|
||||||
h_out = floor(h_out);
|
|
||||||
w_out = floor(w_out);
|
|
||||||
}
|
|
||||||
primitive->AddAttr(kPads, MakeValue(pad_list));
|
|
||||||
std::vector<int64_t> out_shape = {x_shape[0], out_channel * x_shape[1], h_out, w_out};
|
|
||||||
if (format == NHWC) {
|
|
||||||
out_shape = {x_shape[0], h_out, w_out, out_channel * x_shape[1]};
|
|
||||||
}
|
|
||||||
return std::make_shared<abstract::Shape>(out_shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
TypePtr DepthWiseConv2DInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
|
||||||
CheckAndConvertUtils::CheckInRange<size_t>("", input_args.size(), kIncludeBoth, {2, 3}, prim->name());
|
|
||||||
for (const auto &item : input_args) {
|
|
||||||
MS_EXCEPTION_IF_NULL(item);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::map<std::string, TypePtr> types;
|
|
||||||
types.emplace("x", input_args[0]->BuildType());
|
|
||||||
types.emplace("w", input_args[1]->BuildType());
|
|
||||||
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
|
||||||
if (*infer_type == *kInt8) {
|
|
||||||
return kInt32;
|
|
||||||
}
|
|
||||||
return infer_type;
|
|
||||||
}
|
|
||||||
|
|
||||||
AbstractBasePtr DepthWiseConv2DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
|
||||||
return std::make_shared<abstract::AbstractTensor>(DepthWiseConv2DInferType(primitive, input_args),
|
|
||||||
DepthWiseConv2DInferShape(primitive, input_args)->shape());
|
|
||||||
}
|
|
||||||
REGISTER_PRIMITIVE_C(kNameDepthWiseConv2D, DepthWiseConv2D);
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,67 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_CORE_OPS_DEPTHWISE_CONV2D_H
|
|
||||||
#define MINDSPORE_CORE_OPS_DEPTHWISE_CONV2D_H
|
|
||||||
#include <map>
|
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
|
||||||
#include <memory>
|
|
||||||
#include "ops/primitive_c.h"
|
|
||||||
#include "abstract/abstract_value.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
constexpr auto kNameDepthWiseConv2D = "DepthwiseConv2dNative";
|
|
||||||
class DepthWiseConv2D : public PrimitiveC {
|
|
||||||
public:
|
|
||||||
DepthWiseConv2D() : PrimitiveC(kNameDepthWiseConv2D) { InitIOName({"x", "w"}, {"output"}); }
|
|
||||||
explicit DepthWiseConv2D(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x", "w"}, {"output"}); }
|
|
||||||
~DepthWiseConv2D() = default;
|
|
||||||
MS_DECLARE_PARENT(DepthWiseConv2D, PrimitiveC);
|
|
||||||
void Init(const int64_t out_channel, const std::vector<int64_t> &kernel_size, const int64_t mode = 1,
|
|
||||||
const PadMode &pad_mode = VALID, const std::vector<int64_t> &pad = {0, 0, 0, 0},
|
|
||||||
const std::vector<int64_t> &stride = {1, 1, 1, 1}, const std::vector<int64_t> &dilation = {1, 1, 1, 1},
|
|
||||||
const int64_t group = 1);
|
|
||||||
void set_kernel_size(const std::vector<int64_t> &kernel_size);
|
|
||||||
void set_stride(const std::vector<int64_t> &stride);
|
|
||||||
void set_dilation(const std::vector<int64_t> &dilation);
|
|
||||||
void set_pad_mode(const PadMode &pad_mode);
|
|
||||||
void set_pad(const std::vector<int64_t> &pad);
|
|
||||||
void set_mode(const int64_t mode);
|
|
||||||
void set_group(const int64_t group);
|
|
||||||
void set_out_channel(const int64_t out_channel);
|
|
||||||
void set_pads(const std::vector<int64_t> &pad_list);
|
|
||||||
void set_format(const Format &format);
|
|
||||||
std::vector<int64_t> get_kernel_size() const;
|
|
||||||
std::vector<int64_t> get_stride() const;
|
|
||||||
std::vector<int64_t> get_dilation() const;
|
|
||||||
PadMode get_pad_mode() const;
|
|
||||||
std::vector<int64_t> get_pad() const;
|
|
||||||
std::vector<int64_t> get_pads() const;
|
|
||||||
int64_t get_mode() const;
|
|
||||||
int64_t get_group() const;
|
|
||||||
int64_t get_out_channel() const;
|
|
||||||
Format get_format() const;
|
|
||||||
};
|
|
||||||
AbstractBasePtr DepthWiseConv2DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
|
||||||
using PrimDepthWiseConv2DPtr = std::shared_ptr<DepthWiseConv2D>;
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
||||||
|
|
||||||
#endif // MINDSPORE_CORE_OPS_DEPTHWISE_CONV2D_H
|
|
|
@ -14,12 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <set>
|
|
||||||
#include <vector>
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "ops/erf.h"
|
#include "ops/erf.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
#include "ops/op_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
|
|
@ -30,11 +30,8 @@ class Erf : public PrimitiveC {
|
||||||
Erf() : PrimitiveC(kNameErf) { InitIOName({"x"}, {"y"}); }
|
Erf() : PrimitiveC(kNameErf) { InitIOName({"x"}, {"y"}); }
|
||||||
~Erf() = default;
|
~Erf() = default;
|
||||||
MS_DECLARE_PARENT(Erf, PrimitiveC);
|
MS_DECLARE_PARENT(Erf, PrimitiveC);
|
||||||
void Init() {}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
AbstractBasePtr ErfInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
|
||||||
using PrimErfPtr = std::shared_ptr<Erf>;
|
using PrimErfPtr = std::shared_ptr<Erf>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -23,7 +23,7 @@ void Conv2dTransposeFusion::Init(int64_t in_channel, int64_t out_channel, const
|
||||||
int64_t mode, const PadMode &pad_mode, const std::vector<int64_t> &pad,
|
int64_t mode, const PadMode &pad_mode, const std::vector<int64_t> &pad,
|
||||||
const std::vector<int64_t> &stride, const std::vector<int64_t> &dilation,
|
const std::vector<int64_t> &stride, const std::vector<int64_t> &dilation,
|
||||||
int64_t group, const Format &format, const std::vector<int64_t> &pad_list,
|
int64_t group, const Format &format, const std::vector<int64_t> &pad_list,
|
||||||
const std::vector<int64_t> &output_paddings, const ActivationType activation_type) {
|
const std::vector<int64_t> &output_paddings, ActivationType activation_type) {
|
||||||
set_in_channel(in_channel);
|
set_in_channel(in_channel);
|
||||||
set_out_channel(out_channel);
|
set_out_channel(out_channel);
|
||||||
set_kernel_size(kernel_size);
|
set_kernel_size(kernel_size);
|
||||||
|
@ -56,20 +56,20 @@ void Conv2dTransposeFusion::set_dilation(const std::vector<int64_t> &dilation) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void Conv2dTransposeFusion::set_output_paddings(const std::vector<int64_t> &output_paddings) {
|
void Conv2dTransposeFusion::set_output_paddings(const std::vector<int64_t> &output_paddings) {
|
||||||
CheckAndConvertUtils::CheckInteger(koutputPaddings, output_paddings.size(), kGreaterEqual, 1, name());
|
CheckAndConvertUtils::CheckInteger(kOutputPaddings, output_paddings.size(), kGreaterEqual, 1, name());
|
||||||
for (int64_t item : output_paddings) {
|
for (int64_t item : output_paddings) {
|
||||||
CheckAndConvertUtils::CheckInteger(koutputPaddings, item, kGreaterEqual, 0, name());
|
CheckAndConvertUtils::CheckInteger(kOutputPaddings, item, kGreaterEqual, 0, name());
|
||||||
}
|
}
|
||||||
AddAttr(koutputPaddings, MakeValue(output_paddings));
|
AddAttr(kOutputPaddings, MakeValue(output_paddings));
|
||||||
}
|
}
|
||||||
|
|
||||||
void Conv2dTransposeFusion::set_activation_type(const ActivationType activation_type) {
|
void Conv2dTransposeFusion::set_activation_type(ActivationType activation_type) {
|
||||||
int64_t swi = activation_type;
|
int64_t swi = activation_type;
|
||||||
this->AddAttr(kActivationType, MakeValue(swi));
|
this->AddAttr(kActivationType, MakeValue(swi));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int64_t> Conv2dTransposeFusion::get_output_paddings() const {
|
std::vector<int64_t> Conv2dTransposeFusion::get_output_paddings() const {
|
||||||
auto value_ptr = GetAttr(koutputPaddings);
|
auto value_ptr = GetAttr(kOutputPaddings);
|
||||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -36,11 +36,11 @@ class Conv2dTransposeFusion : public Conv2dTranspose {
|
||||||
const PadMode &pad_mode = VALID, const std::vector<int64_t> &pad = {0, 0, 0, 0},
|
const PadMode &pad_mode = VALID, const std::vector<int64_t> &pad = {0, 0, 0, 0},
|
||||||
const std::vector<int64_t> &stride = {1, 1}, const std::vector<int64_t> &dilation = {1, 1},
|
const std::vector<int64_t> &stride = {1, 1}, const std::vector<int64_t> &dilation = {1, 1},
|
||||||
int64_t group = 1, const Format &format = NCHW, const std::vector<int64_t> &pad_list = {0, 0, 0, 0},
|
int64_t group = 1, const Format &format = NCHW, const std::vector<int64_t> &pad_list = {0, 0, 0, 0},
|
||||||
const std::vector<int64_t> &output_paddings = {0}, const ActivationType activation_type = NO_ACTIVATION);
|
const std::vector<int64_t> &output_paddings = {0}, ActivationType activation_type = NO_ACTIVATION);
|
||||||
void set_kernel_size(const std::vector<int64_t> &kernel_size);
|
void set_kernel_size(const std::vector<int64_t> &kernel_size) override;
|
||||||
void set_dilation(const std::vector<int64_t> &dilation);
|
void set_dilation(const std::vector<int64_t> &dilation) override;
|
||||||
void set_output_paddings(const std::vector<int64_t> &output_paddings);
|
void set_output_paddings(const std::vector<int64_t> &output_paddings);
|
||||||
void set_activation_type(const ActivationType activation_type);
|
void set_activation_type(ActivationType activation_type);
|
||||||
|
|
||||||
std::vector<int64_t> get_output_paddings() const;
|
std::vector<int64_t> get_output_paddings() const;
|
||||||
ActivationType get_activation_type() const;
|
ActivationType get_activation_type() const;
|
||||||
|
|
|
@ -1,75 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "ops/fusion/depthwise_conv2d_fusion.h"
|
|
||||||
#include <string>
|
|
||||||
#include "ops/op_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
void DepthWiseConv2DFusion::Init(const int64_t channel_multiplier, const std::vector<int64_t> &kernel_size,
|
|
||||||
const int64_t mode, const PadMode &pad_mode, const std::vector<int64_t> &pad,
|
|
||||||
const std::vector<int64_t> &stride, const std::vector<int64_t> &dilation,
|
|
||||||
const int64_t group, const ActivationType &activation_type) {
|
|
||||||
auto prim_name = this->name();
|
|
||||||
this->set_format(NCHW);
|
|
||||||
this->AddAttr("offset_a", MakeValue(0));
|
|
||||||
this->set_mode(CheckAndConvertUtils::CheckInteger("mode", mode, kEqual, 3, prim_name));
|
|
||||||
|
|
||||||
this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, prim_name));
|
|
||||||
auto strides = CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name());
|
|
||||||
if (strides[0] != strides[1]) {
|
|
||||||
MS_EXCEPTION(ValueError) << "The height and width of stride should be equal, but got height " << strides[0]
|
|
||||||
<< ", width " << strides[1];
|
|
||||||
}
|
|
||||||
this->set_stride(strides);
|
|
||||||
auto dilations = CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name());
|
|
||||||
if (dilations[0] != dilations[1]) {
|
|
||||||
MS_EXCEPTION(ValueError) << "The height and width of dilation should be equal, but got height " << dilations[0]
|
|
||||||
<< ", width " << dilations[1];
|
|
||||||
}
|
|
||||||
this->set_dilation(dilations);
|
|
||||||
this->set_pad_mode(pad_mode);
|
|
||||||
|
|
||||||
CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, prim_name);
|
|
||||||
if (pad_mode == PAD) {
|
|
||||||
for (auto item : pad) {
|
|
||||||
CheckAndConvertUtils::Check("pad_item", item, kGreaterEqual, "zeros_list", 0, prim_name);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, prim_name);
|
|
||||||
}
|
|
||||||
this->set_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name()));
|
|
||||||
|
|
||||||
this->set_out_channel(
|
|
||||||
CheckAndConvertUtils::CheckInteger("channel_multiplier", channel_multiplier, kGreaterThan, 0, prim_name));
|
|
||||||
this->set_group(CheckAndConvertUtils::CheckInteger("group", group, kGreaterThan, 0, prim_name));
|
|
||||||
this->set_activation_type(activation_type);
|
|
||||||
}
|
|
||||||
|
|
||||||
void DepthWiseConv2DFusion::set_activation_type(const ActivationType &activation_type) {
|
|
||||||
int64_t swi;
|
|
||||||
swi = activation_type;
|
|
||||||
this->AddAttr(kActivationType, MakeValue(swi));
|
|
||||||
}
|
|
||||||
|
|
||||||
ActivationType DepthWiseConv2DFusion::get_activation_type() const {
|
|
||||||
auto value_ptr = GetAttr(kActivationType);
|
|
||||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
|
||||||
}
|
|
||||||
REGISTER_PRIMITIVE_C(kNameDepthWiseConv2DFusion, DepthWiseConv2DFusion);
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,41 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_CORE_OPS_DEPTHWISE_CONV2D_FUSION_H_
|
|
||||||
#define MINDSPORE_CORE_OPS_DEPTHWISE_CONV2D_FUSION_H_
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "ops/depthwise_conv2d.h"
|
|
||||||
#include "ops/op_utils.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
constexpr auto kNameDepthWiseConv2DFusion = "DepthWiseConv2DFusion";
|
|
||||||
class DepthWiseConv2DFusion : public DepthWiseConv2D {
|
|
||||||
public:
|
|
||||||
MS_DECLARE_PARENT(DepthWiseConv2DFusion, DepthWiseConv2D);
|
|
||||||
void Init(const int64_t out_channel, const std::vector<int64_t> &kernel_size, const int64_t mode = 1,
|
|
||||||
const PadMode &pad_mode = VALID, const std::vector<int64_t> &pad = {0, 0, 0, 0},
|
|
||||||
const std::vector<int64_t> &stride = {1, 1, 1, 1}, const std::vector<int64_t> &dilation = {1, 1, 1, 1},
|
|
||||||
const int64_t group = 1, const ActivationType &activation_type = NO_ACTIVATION);
|
|
||||||
void set_activation_type(const ActivationType &activation_type);
|
|
||||||
ActivationType get_activation_type() const;
|
|
||||||
};
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
||||||
|
|
||||||
#endif // MINDSPORE_CORE_OPS_DEPTHWISE_CONV2D_FUSION_H_
|
|
|
@ -32,7 +32,6 @@ class AbsGrad : public PrimitiveC {
|
||||||
AbsGrad() : PrimitiveC(kNameAbsGrad) {}
|
AbsGrad() : PrimitiveC(kNameAbsGrad) {}
|
||||||
~AbsGrad() = default;
|
~AbsGrad() = default;
|
||||||
MS_DECLARE_PARENT(AbsGrad, PrimitiveC);
|
MS_DECLARE_PARENT(AbsGrad, PrimitiveC);
|
||||||
void Init() {}
|
|
||||||
};
|
};
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -23,8 +23,8 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
void StridedSliceGrad::Init(const int64_t begin_mask, const int64_t end_mask, const int64_t ellipsis_mask,
|
void StridedSliceGrad::Init(int64_t begin_mask, int64_t end_mask, int64_t ellipsis_mask, int64_t new_axis_mask,
|
||||||
const int64_t new_axis_mask, const int64_t shrink_axis_mask) {
|
int64_t shrink_axis_mask) {
|
||||||
this->set_begin_mask(begin_mask);
|
this->set_begin_mask(begin_mask);
|
||||||
this->set_end_mask(end_mask);
|
this->set_end_mask(end_mask);
|
||||||
this->set_ellipsis_mask(ellipsis_mask);
|
this->set_ellipsis_mask(ellipsis_mask);
|
||||||
|
@ -32,7 +32,7 @@ void StridedSliceGrad::Init(const int64_t begin_mask, const int64_t end_mask, co
|
||||||
this->set_shrink_axis_mask(shrink_axis_mask);
|
this->set_shrink_axis_mask(shrink_axis_mask);
|
||||||
}
|
}
|
||||||
|
|
||||||
void StridedSliceGrad::set_begin_mask(const int64_t begin_mask) {
|
void StridedSliceGrad::set_begin_mask(int64_t begin_mask) {
|
||||||
CheckAndConvertUtils::CheckInteger(kBeginMask, begin_mask, kGreaterEqual, 0, this->name());
|
CheckAndConvertUtils::CheckInteger(kBeginMask, begin_mask, kGreaterEqual, 0, this->name());
|
||||||
this->AddAttr(kBeginMask, MakeValue(begin_mask));
|
this->AddAttr(kBeginMask, MakeValue(begin_mask));
|
||||||
}
|
}
|
||||||
|
@ -40,7 +40,7 @@ int64_t StridedSliceGrad::get_begin_mask() const {
|
||||||
auto value_ptr = GetAttr(kBeginMask);
|
auto value_ptr = GetAttr(kBeginMask);
|
||||||
return GetValue<int64_t>(value_ptr);
|
return GetValue<int64_t>(value_ptr);
|
||||||
}
|
}
|
||||||
void StridedSliceGrad::set_end_mask(const int64_t end_mask) {
|
void StridedSliceGrad::set_end_mask(int64_t end_mask) {
|
||||||
CheckAndConvertUtils::CheckInteger(kEndMask, end_mask, kGreaterEqual, 0, this->name());
|
CheckAndConvertUtils::CheckInteger(kEndMask, end_mask, kGreaterEqual, 0, this->name());
|
||||||
this->AddAttr(kEndMask, MakeValue(end_mask));
|
this->AddAttr(kEndMask, MakeValue(end_mask));
|
||||||
}
|
}
|
||||||
|
@ -48,7 +48,7 @@ int64_t StridedSliceGrad::get_end_mask() const {
|
||||||
auto value_ptr = GetAttr(kEndMask);
|
auto value_ptr = GetAttr(kEndMask);
|
||||||
return GetValue<int64_t>(value_ptr);
|
return GetValue<int64_t>(value_ptr);
|
||||||
}
|
}
|
||||||
void StridedSliceGrad::set_ellipsis_mask(const int64_t ellipsis_mask) {
|
void StridedSliceGrad::set_ellipsis_mask(int64_t ellipsis_mask) {
|
||||||
CheckAndConvertUtils::CheckInteger(kEllipsisMask, ellipsis_mask, kGreaterEqual, 0, this->name());
|
CheckAndConvertUtils::CheckInteger(kEllipsisMask, ellipsis_mask, kGreaterEqual, 0, this->name());
|
||||||
std::bitset<sizeof(int64_t) * 8> bs(ellipsis_mask);
|
std::bitset<sizeof(int64_t) * 8> bs(ellipsis_mask);
|
||||||
std::ostringstream buffer;
|
std::ostringstream buffer;
|
||||||
|
@ -62,7 +62,7 @@ int64_t StridedSliceGrad::get_ellipsis_mask() const {
|
||||||
auto value_ptr = GetAttr(kEllipsisMask);
|
auto value_ptr = GetAttr(kEllipsisMask);
|
||||||
return GetValue<int64_t>(value_ptr);
|
return GetValue<int64_t>(value_ptr);
|
||||||
}
|
}
|
||||||
void StridedSliceGrad::set_new_axis_mask(const int64_t new_axis_mask) {
|
void StridedSliceGrad::set_new_axis_mask(int64_t new_axis_mask) {
|
||||||
CheckAndConvertUtils::CheckInteger(kNewAxisMask, new_axis_mask, kGreaterEqual, 0, this->name());
|
CheckAndConvertUtils::CheckInteger(kNewAxisMask, new_axis_mask, kGreaterEqual, 0, this->name());
|
||||||
this->AddAttr(kNewAxisMask, MakeValue(new_axis_mask));
|
this->AddAttr(kNewAxisMask, MakeValue(new_axis_mask));
|
||||||
}
|
}
|
||||||
|
@ -70,7 +70,7 @@ int64_t StridedSliceGrad::get_new_axis_mask() const {
|
||||||
auto value_ptr = GetAttr(kNewAxisMask);
|
auto value_ptr = GetAttr(kNewAxisMask);
|
||||||
return GetValue<int64_t>(value_ptr);
|
return GetValue<int64_t>(value_ptr);
|
||||||
}
|
}
|
||||||
void StridedSliceGrad::set_shrink_axis_mask(const int64_t shrink_axis_mask) {
|
void StridedSliceGrad::set_shrink_axis_mask(int64_t shrink_axis_mask) {
|
||||||
CheckAndConvertUtils::CheckInteger(kShrinkAxisMask, shrink_axis_mask, kGreaterEqual, 0, this->name());
|
CheckAndConvertUtils::CheckInteger(kShrinkAxisMask, shrink_axis_mask, kGreaterEqual, 0, this->name());
|
||||||
this->AddAttr(kShrinkAxisMask, MakeValue(shrink_axis_mask));
|
this->AddAttr(kShrinkAxisMask, MakeValue(shrink_axis_mask));
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,13 +33,13 @@ class StridedSliceGrad : public PrimitiveC {
|
||||||
StridedSliceGrad() : PrimitiveC(kNameStridedSliceGrad) {}
|
StridedSliceGrad() : PrimitiveC(kNameStridedSliceGrad) {}
|
||||||
~StridedSliceGrad() = default;
|
~StridedSliceGrad() = default;
|
||||||
MS_DECLARE_PARENT(StridedSliceGrad, PrimitiveC);
|
MS_DECLARE_PARENT(StridedSliceGrad, PrimitiveC);
|
||||||
void Init(const int64_t begin_mask = 0, const int64_t end_mask = 0, const int64_t ellipsis_mask = 0,
|
void Init(int64_t begin_mask = 0, int64_t end_mask = 0, int64_t ellipsis_mask = 0, int64_t new_axis_mask = 0,
|
||||||
const int64_t new_axis_mask = 0, const int64_t shrink_axis_mask = 0);
|
int64_t shrink_axis_mask = 0);
|
||||||
void set_begin_mask(const int64_t begin_mask);
|
void set_begin_mask(int64_t begin_mask);
|
||||||
void set_end_mask(const int64_t end_mask);
|
void set_end_mask(int64_t end_mask);
|
||||||
void set_ellipsis_mask(const int64_t ellipsis_mask);
|
void set_ellipsis_mask(int64_t ellipsis_mask);
|
||||||
void set_new_axis_mask(const int64_t new_axis_mask);
|
void set_new_axis_mask(int64_t new_axis_mask);
|
||||||
void set_shrink_axis_mask(const int64_t shrink_axis_mask);
|
void set_shrink_axis_mask(int64_t shrink_axis_mask);
|
||||||
int64_t get_begin_mask() const;
|
int64_t get_begin_mask() const;
|
||||||
int64_t get_end_mask() const;
|
int64_t get_end_mask() const;
|
||||||
int64_t get_ellipsis_mask() const;
|
int64_t get_ellipsis_mask() const;
|
||||||
|
|
|
@ -18,9 +18,9 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
void GRU::Init(const bool bidirectional, const int64_t cell_depth, const float keep_prob, const float cell_clip,
|
void GRU::Init(bool bidirectional, int64_t cell_depth, float keep_prob, float cell_clip, int64_t num_proj,
|
||||||
const int64_t num_proj, const bool time_major, const bool reset_after, const bool is_training,
|
bool time_major, bool reset_after, bool is_training, ActivationType activation,
|
||||||
const ActivationType activation, const GateOrderMode gate_order) {
|
GateOrderMode gate_order) {
|
||||||
this->set_bidirectional(bidirectional);
|
this->set_bidirectional(bidirectional);
|
||||||
this->set_cell_depth(cell_depth);
|
this->set_cell_depth(cell_depth);
|
||||||
this->set_keep_prob(keep_prob);
|
this->set_keep_prob(keep_prob);
|
||||||
|
@ -33,31 +33,31 @@ void GRU::Init(const bool bidirectional, const int64_t cell_depth, const float k
|
||||||
this->set_gate_order(gate_order);
|
this->set_gate_order(gate_order);
|
||||||
}
|
}
|
||||||
|
|
||||||
void GRU::set_bidirectional(const bool bidirectional) { AddAttr(kBidirectional, MakeValue(bidirectional)); }
|
void GRU::set_bidirectional(bool bidirectional) { AddAttr(kBidirectional, MakeValue(bidirectional)); }
|
||||||
|
|
||||||
void GRU::set_cell_depth(const int64_t cell_depth) { AddAttr(kCellDepth, MakeValue(cell_depth)); }
|
void GRU::set_cell_depth(int64_t cell_depth) { AddAttr(kCellDepth, MakeValue(cell_depth)); }
|
||||||
|
|
||||||
void GRU::set_keep_prob(const float keep_prob) { AddAttr(kKeepProb, MakeValue(keep_prob)); }
|
void GRU::set_keep_prob(float keep_prob) { AddAttr(kKeepProb, MakeValue(keep_prob)); }
|
||||||
|
|
||||||
void GRU::set_cell_clip(const float cell_clip) { AddAttr(kCellClip, MakeValue(cell_clip)); }
|
void GRU::set_cell_clip(float cell_clip) { AddAttr(kCellClip, MakeValue(cell_clip)); }
|
||||||
|
|
||||||
void GRU::set_num_proj(const int64_t num_proj) {
|
void GRU::set_num_proj(int64_t num_proj) {
|
||||||
CheckAndConvertUtils::CheckInteger(kNumProj, num_proj, kGreaterThan, 0, this->name());
|
CheckAndConvertUtils::CheckInteger(kNumProj, num_proj, kGreaterThan, 0, this->name());
|
||||||
AddAttr(kNumProj, MakeValue(num_proj));
|
AddAttr(kNumProj, MakeValue(num_proj));
|
||||||
}
|
}
|
||||||
|
|
||||||
void GRU::set_time_major(const bool time_major) { AddAttr(kTimeMajor, MakeValue(time_major)); }
|
void GRU::set_time_major(bool time_major) { AddAttr(kTimeMajor, MakeValue(time_major)); }
|
||||||
|
|
||||||
void GRU::set_reset_after(const bool reset_after) { AddAttr(kResetAfter, MakeValue(reset_after)); }
|
void GRU::set_reset_after(bool reset_after) { AddAttr(kResetAfter, MakeValue(reset_after)); }
|
||||||
|
|
||||||
void GRU::set_is_training(const bool is_training) { AddAttr(kIsTraining, MakeValue(is_training)); }
|
void GRU::set_is_training(bool is_training) { AddAttr(kIsTraining, MakeValue(is_training)); }
|
||||||
|
|
||||||
void GRU::set_activation(const ActivationType activation) {
|
void GRU::set_activation(ActivationType activation) {
|
||||||
int64_t swi = activation;
|
int64_t swi = activation;
|
||||||
AddAttr(kActivation, MakeValue(swi));
|
AddAttr(kActivation, MakeValue(swi));
|
||||||
}
|
}
|
||||||
|
|
||||||
void GRU::set_gate_order(const GateOrderMode gate_order) {
|
void GRU::set_gate_order(GateOrderMode gate_order) {
|
||||||
int64_t swi = gate_order;
|
int64_t swi = gate_order;
|
||||||
AddAttr(kGateOrder, MakeValue(swi));
|
AddAttr(kGateOrder, MakeValue(swi));
|
||||||
}
|
}
|
||||||
|
|
|
@ -39,22 +39,20 @@ class GRU : public PrimitiveC {
|
||||||
}
|
}
|
||||||
~GRU() = default;
|
~GRU() = default;
|
||||||
MS_DECLARE_PARENT(GRU, PrimitiveC);
|
MS_DECLARE_PARENT(GRU, PrimitiveC);
|
||||||
void Init(const bool bidirectional = false, const int64_t cell_depth = 1, const float keep_prob = 1.0,
|
void Init(bool bidirectional = false, int64_t cell_depth = 1, float keep_prob = 1.0, float cell_clip = -1.0,
|
||||||
const float cell_clip = -1.0, const int64_t num_proj = 0, const bool time_major = true,
|
int64_t num_proj = 0, bool time_major = true, bool reset_after = true, bool is_training = true,
|
||||||
const bool reset_after = true, const bool is_training = true,
|
ActivationType activation = ActivationType::TANH, GateOrderMode gate_order = GateOrderMode::RZH);
|
||||||
const ActivationType activation = ActivationType::TANH,
|
|
||||||
const GateOrderMode gate_order = GateOrderMode::RZH);
|
|
||||||
|
|
||||||
void set_bidirectional(const bool bidirectional);
|
void set_bidirectional(bool bidirectional);
|
||||||
void set_cell_depth(const int64_t cell_depth);
|
void set_cell_depth(int64_t cell_depth);
|
||||||
void set_keep_prob(const float keep_prob);
|
void set_keep_prob(float keep_prob);
|
||||||
void set_cell_clip(const float cell_clip);
|
void set_cell_clip(float cell_clip);
|
||||||
void set_num_proj(const int64_t num_proj);
|
void set_num_proj(int64_t num_proj);
|
||||||
void set_time_major(const bool time_major);
|
void set_time_major(bool time_major);
|
||||||
void set_reset_after(const bool reset_after);
|
void set_reset_after(bool reset_after);
|
||||||
void set_is_training(const bool is_training);
|
void set_is_training(bool is_training);
|
||||||
void set_activation(const ActivationType activation);
|
void set_activation(ActivationType activation);
|
||||||
void set_gate_order(const GateOrderMode gate_order);
|
void set_gate_order(GateOrderMode gate_order);
|
||||||
|
|
||||||
bool get_bidirectional() const;
|
bool get_bidirectional() const;
|
||||||
int64_t get_cell_depth() const;
|
int64_t get_cell_depth() const;
|
||||||
|
@ -68,8 +66,6 @@ class GRU : public PrimitiveC {
|
||||||
GateOrderMode get_gate_order() const;
|
GateOrderMode get_gate_order() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
AbstractBasePtr GRUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
|
||||||
using PrimGRUPtr = std::shared_ptr<GRU>;
|
using PrimGRUPtr = std::shared_ptr<GRU>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,12 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <set>
|
|
||||||
#include <vector>
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "ops/invert_permutation.h"
|
#include "ops/invert_permutation.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
#include "ops/op_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
|
|
@ -30,11 +30,8 @@ class InvertPermutation : public PrimitiveC {
|
||||||
InvertPermutation() : PrimitiveC(kNameInvertPermutation) {}
|
InvertPermutation() : PrimitiveC(kNameInvertPermutation) {}
|
||||||
~InvertPermutation() = default;
|
~InvertPermutation() = default;
|
||||||
MS_DECLARE_PARENT(InvertPermutation, PrimitiveC);
|
MS_DECLARE_PARENT(InvertPermutation, PrimitiveC);
|
||||||
void Init() {}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
AbstractBasePtr InvertPermutationInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
|
||||||
using PrimInvertPermutationPtr = std::shared_ptr<InvertPermutation>;
|
using PrimInvertPermutationPtr = std::shared_ptr<InvertPermutation>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -31,7 +31,6 @@ class LinSpace : public PrimitiveC {
|
||||||
LinSpace() : PrimitiveC(kNameLinSpace) { InitIOName({"start", "stop", "num"}, {"output"}); }
|
LinSpace() : PrimitiveC(kNameLinSpace) { InitIOName({"start", "stop", "num"}, {"output"}); }
|
||||||
~LinSpace() = default;
|
~LinSpace() = default;
|
||||||
MS_DECLARE_PARENT(LinSpace, PrimitiveC);
|
MS_DECLARE_PARENT(LinSpace, PrimitiveC);
|
||||||
void Init() {}
|
|
||||||
};
|
};
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,95 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "ops/local_response_normalization.h"
|
|
||||||
#include <string>
|
|
||||||
#include <algorithm>
|
|
||||||
#include <memory>
|
|
||||||
#include <set>
|
|
||||||
#include <vector>
|
|
||||||
#include "ops/op_utils.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
#include "abstract/primitive_infer_map.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
namespace {
|
|
||||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
|
||||||
auto x = input_args[0]->BuildShape();
|
|
||||||
auto shape_element = x->cast<abstract::ShapePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(shape_element);
|
|
||||||
return shape_element;
|
|
||||||
}
|
|
||||||
|
|
||||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
|
||||||
MS_EXCEPTION_IF_NULL(prim);
|
|
||||||
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim->name());
|
|
||||||
for (const auto &item : input_args) {
|
|
||||||
MS_EXCEPTION_IF_NULL(item);
|
|
||||||
}
|
|
||||||
std::map<std::string, TypePtr> types;
|
|
||||||
types.emplace("x", input_args[0]->BuildType());
|
|
||||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
|
||||||
}
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
void LocalResponseNormalization::Init(const int64_t depth_radius, const float bias, const float alpha,
|
|
||||||
const float beta) {
|
|
||||||
this->set_depth_radius(depth_radius);
|
|
||||||
this->set_bias(bias);
|
|
||||||
this->set_alpha(alpha);
|
|
||||||
this->set_beta(beta);
|
|
||||||
}
|
|
||||||
|
|
||||||
void LocalResponseNormalization::set_depth_radius(const int64_t depth_radius) {
|
|
||||||
this->AddAttr(kDepthRadius, MakeValue(depth_radius));
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t LocalResponseNormalization::get_depth_radius() const {
|
|
||||||
auto value_ptr = GetAttr(kDepthRadius);
|
|
||||||
return GetValue<int64_t>(value_ptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
void LocalResponseNormalization::set_bias(const float bias) { this->AddAttr(kBias, MakeValue(bias)); }
|
|
||||||
|
|
||||||
float LocalResponseNormalization::get_bias() const {
|
|
||||||
auto value_ptr = GetAttr(kBias);
|
|
||||||
return GetValue<float>(value_ptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
void LocalResponseNormalization::set_alpha(const float alpha) { this->AddAttr(kAlpha, MakeValue(alpha)); }
|
|
||||||
|
|
||||||
float LocalResponseNormalization::get_alpha() const {
|
|
||||||
auto value_ptr = GetAttr(kAlpha);
|
|
||||||
return GetValue<float>(value_ptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
void LocalResponseNormalization::set_beta(const float beta) { this->AddAttr(kBeta, MakeValue(beta)); }
|
|
||||||
|
|
||||||
float LocalResponseNormalization::get_beta() const {
|
|
||||||
auto value_ptr = GetAttr(kBeta);
|
|
||||||
return GetValue<float>(value_ptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
AbstractBasePtr LocalResponseNormalizationInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
|
||||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
|
||||||
InferShape(primitive, input_args)->shape());
|
|
||||||
}
|
|
||||||
REGISTER_PRIMITIVE_C(kNameLocalResponseNormalization, LocalResponseNormalization);
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,53 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_CORE_OPS_LOCAL_RESPONSE_NORMALIZATION_H_
|
|
||||||
#define MINDSPORE_CORE_OPS_LOCAL_RESPONSE_NORMALIZATION_H_
|
|
||||||
#include <map>
|
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
|
||||||
#include <memory>
|
|
||||||
#include "ops/primitive_c.h"
|
|
||||||
#include "abstract/abstract_value.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
constexpr auto kNameLocalResponseNormalization = "LocalResponseNormalization";
|
|
||||||
class LocalResponseNormalization : public PrimitiveC {
|
|
||||||
public:
|
|
||||||
LocalResponseNormalization() : PrimitiveC(kNameLocalResponseNormalization) {}
|
|
||||||
~LocalResponseNormalization() = default;
|
|
||||||
MS_DECLARE_PARENT(LocalResponseNormalization, PrimitiveC);
|
|
||||||
void Init(const int64_t depth_radius, const float bias, const float alpha, const float beta);
|
|
||||||
void set_depth_radius(const int64_t depth_radius);
|
|
||||||
void set_bias(const float bias);
|
|
||||||
void set_alpha(const float alpha);
|
|
||||||
void set_beta(const float beta);
|
|
||||||
|
|
||||||
int64_t get_depth_radius() const;
|
|
||||||
float get_bias() const;
|
|
||||||
float get_alpha() const;
|
|
||||||
float get_beta() const;
|
|
||||||
};
|
|
||||||
|
|
||||||
AbstractBasePtr LocalResponseNormalizationInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
|
||||||
using PrimLocalResponseNormalizationPtr = std::shared_ptr<LocalResponseNormalization>;
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
||||||
|
|
||||||
#endif // MINDSPORE_CORE_OPS_LOCAL_RESPONSE_NORMALIZATION_H_
|
|
|
@ -1,35 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "ops/loop.h"
|
|
||||||
#include "ops/op_utils.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
void Loop::Init(const int64_t sub_graph_index) { this->set_sub_graph_index(sub_graph_index); }
|
|
||||||
|
|
||||||
void Loop::set_sub_graph_index(const int64_t sub_graph_index) {
|
|
||||||
this->AddAttr(kSubGraphIndex, MakeValue(sub_graph_index));
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t Loop::get_sub_graph_index() const {
|
|
||||||
auto value_ptr = this->GetAttr(kSubGraphIndex);
|
|
||||||
return GetValue<int64_t>(value_ptr);
|
|
||||||
}
|
|
||||||
REGISTER_PRIMITIVE_C(kNameLoop, Loop);
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,42 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_CORE_OPS_LOOP_H_
|
|
||||||
#define MINDSPORE_CORE_OPS_LOOP_H_
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
#include "ops/primitive_c.h"
|
|
||||||
#include "abstract/abstract_value.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
constexpr auto kNameLoop = "Loop";
|
|
||||||
class Loop : public PrimitiveC {
|
|
||||||
public:
|
|
||||||
Loop() : PrimitiveC(kNameLoop) {}
|
|
||||||
~Loop() = default;
|
|
||||||
MS_DECLARE_PARENT(Loop, PrimitiveC);
|
|
||||||
void Init(const int64_t sub_graph_index);
|
|
||||||
void set_sub_graph_index(const int64_t sub_graph_index);
|
|
||||||
int64_t get_sub_graph_index() const;
|
|
||||||
};
|
|
||||||
|
|
||||||
using PrimLoopPtr = std::shared_ptr<Loop>;
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
||||||
|
|
||||||
#endif // MINDSPORE_CORE_OPS_LOOP_H_
|
|
|
@ -1,31 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "ops/make_tuple.h"
|
|
||||||
#include <string>
|
|
||||||
#include <algorithm>
|
|
||||||
#include <memory>
|
|
||||||
#include <set>
|
|
||||||
#include <vector>
|
|
||||||
#include "ops/op_utils.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
#include "abstract/primitive_infer_map.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
REGISTER_PRIMITIVE_C(kNameMakeTuple, MakeTuple);
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,40 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_CORE_OPS_MAKE_TUPLE_H_
|
|
||||||
#define MINDSPORE_CORE_OPS_MAKE_TUPLE_H_
|
|
||||||
#include <map>
|
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
|
||||||
#include <memory>
|
|
||||||
#include "ops/primitive_c.h"
|
|
||||||
#include "abstract/abstract_value.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
constexpr auto kNameMakeTuple = "MakeTuple";
|
|
||||||
class MakeTuple : public PrimitiveC {
|
|
||||||
public:
|
|
||||||
MakeTuple() : PrimitiveC(kNameMakeTuple) {}
|
|
||||||
~MakeTuple() = default;
|
|
||||||
MS_DECLARE_PARENT(MakeTuple, PrimitiveC);
|
|
||||||
void Init() {}
|
|
||||||
};
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
||||||
|
|
||||||
#endif // MINDSPORE_CORE_OPS_MAKE_TUPLE_H_
|
|
|
@ -1,74 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "ops/matrix_diag.h"
|
|
||||||
#include <string>
|
|
||||||
#include <algorithm>
|
|
||||||
#include <memory>
|
|
||||||
#include <set>
|
|
||||||
#include <vector>
|
|
||||||
#include "ops/op_utils.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
#include "abstract/primitive_infer_map.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
namespace {
|
|
||||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
|
||||||
auto prim_name = primitive->name();
|
|
||||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
|
||||||
auto assist_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
|
||||||
|
|
||||||
CheckAndConvertUtils::CheckInteger("assist rank", (int64_t)assist_shape.size(), kGreaterEqual, 2, prim_name);
|
|
||||||
CheckAndConvertUtils::Check("x_shape rank", (int64_t)x_shape.size() + 1, kLessEqual, "assist rank",
|
|
||||||
(int64_t)assist_shape.size(), prim_name);
|
|
||||||
CheckAndConvertUtils::Check("assist's penultimate dimension", assist_shape[(int64_t)assist_shape.size() - 2], kEqual,
|
|
||||||
"assist's last dimension", assist_shape[(int64_t)assist_shape.size() - 1], prim_name);
|
|
||||||
|
|
||||||
int64_t x_end_dim = x_shape.size() - 1;
|
|
||||||
int64_t assist_end_dim = assist_shape.size() - 1;
|
|
||||||
while (x_end_dim >= 0) {
|
|
||||||
if (x_shape[x_end_dim] != 1) {
|
|
||||||
CheckAndConvertUtils::Check("reverse x dim", x_shape[x_end_dim], kEqual, "reverse assist dim",
|
|
||||||
assist_shape[assist_end_dim - 1], prim_name);
|
|
||||||
}
|
|
||||||
x_end_dim--;
|
|
||||||
assist_end_dim--;
|
|
||||||
}
|
|
||||||
return std::make_shared<abstract::Shape>(assist_shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
|
||||||
for (const auto &item : input_args) {
|
|
||||||
MS_EXCEPTION_IF_NULL(item);
|
|
||||||
}
|
|
||||||
const std::set<TypePtr> valid_types = {kInt8, kInt32, kUInt8, kFloat16, kFloat32};
|
|
||||||
std::map<std::string, TypePtr> types;
|
|
||||||
types.emplace("x", input_args[0]->BuildType());
|
|
||||||
types.emplace("assist", input_args[1]->BuildType());
|
|
||||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
|
||||||
}
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
AbstractBasePtr MatrixDiagInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
|
||||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
|
||||||
InferShape(primitive, input_args)->shape());
|
|
||||||
}
|
|
||||||
REGISTER_PRIMITIVE_C(kNameMatrixDiag, MatrixDiag);
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,43 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_CORE_OPS_MATRIX_DIAG_H_
|
|
||||||
#define MINDSPORE_CORE_OPS_MATRIX_DIAG_H_
|
|
||||||
#include <map>
|
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
|
||||||
#include <memory>
|
|
||||||
#include "ops/primitive_c.h"
|
|
||||||
#include "abstract/abstract_value.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
constexpr auto kNameMatrixDiag = "MatrixDiag";
|
|
||||||
class MatrixDiag : public PrimitiveC {
|
|
||||||
public:
|
|
||||||
MatrixDiag() : PrimitiveC(kNameMatrixDiag) {}
|
|
||||||
~MatrixDiag() = default;
|
|
||||||
MS_DECLARE_PARENT(MatrixDiag, PrimitiveC);
|
|
||||||
void Init() {}
|
|
||||||
};
|
|
||||||
AbstractBasePtr MatrixDiagInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
|
||||||
using PrimMatrixDiagPtr = std::shared_ptr<MatrixDiag>;
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
||||||
|
|
||||||
#endif // MINDSPORE_CORE_OPS_MATRIX_DIAG_H_
|
|
|
@ -1,23 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "ops/mul_fold.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
REGISTER_PRIMITIVE_C(kNameMulFold, MulFold);
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,44 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_CORE_OPS_MUL_FOLD_H_
|
|
||||||
#define MINDSPORE_CORE_OPS_MUL_FOLD_H_
|
|
||||||
|
|
||||||
#include <map>
|
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
|
||||||
#include <memory>
|
|
||||||
#include <algorithm>
|
|
||||||
#include "ops/op_utils.h"
|
|
||||||
#include "ops/primitive_c.h"
|
|
||||||
#include "abstract/primitive_infer_map.h"
|
|
||||||
#include "abstract/abstract_value.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
constexpr auto kNameMulFold = "MulFold";
|
|
||||||
class MulFold : public PrimitiveC {
|
|
||||||
public:
|
|
||||||
MulFold() : PrimitiveC(kNameMulFold) {}
|
|
||||||
~MulFold() = default;
|
|
||||||
MS_DECLARE_PARENT(MulFold, PrimitiveC);
|
|
||||||
void Init() {}
|
|
||||||
};
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
||||||
|
|
||||||
#endif // MINDSPORE_CORE_OPS_MUL_FOLD_H_
|
|
|
@ -1,31 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "ops/net_output.h"
|
|
||||||
#include <string>
|
|
||||||
#include <algorithm>
|
|
||||||
#include <memory>
|
|
||||||
#include <set>
|
|
||||||
#include <vector>
|
|
||||||
#include "ops/op_utils.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
#include "abstract/primitive_infer_map.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
REGISTER_PRIMITIVE_C(kNameNetOutput, NetOutput);
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,40 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_CORE_OPS_NET_OUTPUT_H_
|
|
||||||
#define MINDSPORE_CORE_OPS_NET_OUTPUT_H_
|
|
||||||
#include <map>
|
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
|
||||||
#include <memory>
|
|
||||||
#include "ops/primitive_c.h"
|
|
||||||
#include "abstract/abstract_value.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
constexpr auto kNameNetOutput = "NetOutput";
|
|
||||||
class NetOutput : public PrimitiveC {
|
|
||||||
public:
|
|
||||||
NetOutput() : PrimitiveC(kNameNetOutput) {}
|
|
||||||
~NetOutput() = default;
|
|
||||||
MS_DECLARE_PARENT(NetOutput, PrimitiveC);
|
|
||||||
void Init() {}
|
|
||||||
};
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
||||||
|
|
||||||
#endif // MINDSPORE_CORE_OPS_NET_OUTPUT_H_
|
|
|
@ -14,12 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <set>
|
|
||||||
#include <vector>
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "ops/non_zero.h"
|
#include "ops/non_zero.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
#include "ops/op_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
|
|
@ -16,7 +16,6 @@
|
||||||
|
|
||||||
#ifndef MINDSPORE_CORE_OPS_NON_ZERO_H_
|
#ifndef MINDSPORE_CORE_OPS_NON_ZERO_H_
|
||||||
#define MINDSPORE_CORE_OPS_NON_ZERO_H_
|
#define MINDSPORE_CORE_OPS_NON_ZERO_H_
|
||||||
#include <vector>
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/primitive_c.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "abstract/abstract_value.h"
|
||||||
|
@ -30,11 +29,7 @@ class NonZero : public PrimitiveC {
|
||||||
NonZero() : PrimitiveC(kNameNonZero) {}
|
NonZero() : PrimitiveC(kNameNonZero) {}
|
||||||
~NonZero() = default;
|
~NonZero() = default;
|
||||||
MS_DECLARE_PARENT(NonZero, PrimitiveC);
|
MS_DECLARE_PARENT(NonZero, PrimitiveC);
|
||||||
void Init() {}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
AbstractBasePtr NonZeroInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
|
||||||
using PrimNonZeroPtr = std::shared_ptr<NonZero>;
|
using PrimNonZeroPtr = std::shared_ptr<NonZero>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,25 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "ops/onnx_int8_dequantize.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
#include "abstract/primitive_infer_map.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
REGISTER_PRIMITIVE_C(kNameOnnxInt8Dequantize, OnnxInt8Dequantize);
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,41 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_CORE_OPS_ONNX_INT8_DEQUANTIZE_H_
|
|
||||||
#define MINDSPORE_CORE_OPS_ONNX_INT8_DEQUANTIZE_H_
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
#include "ops/primitive_c.h"
|
|
||||||
#include "abstract/abstract_value.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
constexpr auto kNameOnnxInt8Dequantize = "OnnxInt8Dequantize";
|
|
||||||
class OnnxInt8Dequantize : public PrimitiveC {
|
|
||||||
public:
|
|
||||||
OnnxInt8Dequantize() : PrimitiveC(kNameOnnxInt8Dequantize) {}
|
|
||||||
~OnnxInt8Dequantize() = default;
|
|
||||||
MS_DECLARE_PARENT(OnnxInt8Dequantize, PrimitiveC);
|
|
||||||
void Init() {}
|
|
||||||
};
|
|
||||||
|
|
||||||
using PrimOnnxInt8DequantizePtr = std::shared_ptr<OnnxInt8Dequantize>;
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
||||||
|
|
||||||
#endif // MINDSPORE_CORE_OPS_ONNX_INT8_DEQUANTIZE_H_
|
|
|
@ -1,31 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "ops/onnx_int8_quantize.h"
|
|
||||||
#include <string>
|
|
||||||
#include <algorithm>
|
|
||||||
#include <memory>
|
|
||||||
#include <set>
|
|
||||||
#include <vector>
|
|
||||||
#include "ops/op_utils.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
#include "abstract/primitive_infer_map.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
REGISTER_PRIMITIVE_C(kNameOnnxInt8Quantize, OnnxInt8Quantize);
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,40 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_CORE_OPS_ONNX_INT8_QUANTIZE_H_
|
|
||||||
#define MINDSPORE_CORE_OPS_ONNX_INT8_QUANTIZE_H_
|
|
||||||
#include <map>
|
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
|
||||||
#include <memory>
|
|
||||||
#include "ops/primitive_c.h"
|
|
||||||
#include "abstract/abstract_value.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
constexpr auto kNameOnnxInt8Quantize = "OnnxInt8Quantize";
|
|
||||||
class OnnxInt8Quantize : public PrimitiveC {
|
|
||||||
public:
|
|
||||||
OnnxInt8Quantize() : PrimitiveC(kNameOnnxInt8Quantize) {}
|
|
||||||
~OnnxInt8Quantize() = default;
|
|
||||||
MS_DECLARE_PARENT(OnnxInt8Quantize, PrimitiveC);
|
|
||||||
void Init() {}
|
|
||||||
};
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
||||||
|
|
||||||
#endif // MINDSPORE_CORE_OPS_ONNX_INT8_QUANTIZE_H_
|
|
|
@ -136,7 +136,7 @@ constexpr auto kOutChannel = "out_channel";
|
||||||
constexpr auto kOutMaxValue = "out_max_value";
|
constexpr auto kOutMaxValue = "out_max_value";
|
||||||
constexpr auto kOutputChannel = "output_channel";
|
constexpr auto kOutputChannel = "output_channel";
|
||||||
constexpr auto kOutputNum = "output_num";
|
constexpr auto kOutputNum = "output_num";
|
||||||
constexpr auto koutputPaddings = "output_paddings";
|
constexpr auto kOutputPaddings = "output_paddings";
|
||||||
constexpr auto kOutputType = "output_type";
|
constexpr auto kOutputType = "output_type";
|
||||||
constexpr auto kOutQuantized = "out_quantized";
|
constexpr auto kOutQuantized = "out_quantized";
|
||||||
constexpr auto kP = "p";
|
constexpr auto kP = "p";
|
||||||
|
|
|
@ -1,38 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include "ops/permute.h"
|
|
||||||
#include <string>
|
|
||||||
#include <algorithm>
|
|
||||||
#include <memory>
|
|
||||||
#include <set>
|
|
||||||
#include <vector>
|
|
||||||
#include "ops/op_utils.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
#include "abstract/primitive_infer_map.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
void Permute::set_order(const std::vector<int64_t> &order) { this->AddAttr(kOrder, MakeValue(order)); }
|
|
||||||
|
|
||||||
std::vector<int64_t> Permute::get_order() const {
|
|
||||||
auto value_ptr = GetAttr(kOrder);
|
|
||||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
void Permute::Init(const std::vector<int64_t> &order) { this->set_order(order); }
|
|
||||||
REGISTER_PRIMITIVE_C(kNamePermute, Permute);
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,43 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_CORE_OPS_PERMUTE_H_
|
|
||||||
#define MINDSPORE_CORE_OPS_PERMUTE_H_
|
|
||||||
#include <map>
|
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
|
||||||
#include <memory>
|
|
||||||
#include "ops/primitive_c.h"
|
|
||||||
#include "abstract/abstract_value.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
constexpr auto kNamePermute = "Permute";
|
|
||||||
class Permute : public PrimitiveC {
|
|
||||||
public:
|
|
||||||
Permute() : PrimitiveC(kNamePermute) {}
|
|
||||||
~Permute() = default;
|
|
||||||
MS_DECLARE_PARENT(Permute, PrimitiveC);
|
|
||||||
|
|
||||||
void Init(const std::vector<int64_t> &order);
|
|
||||||
void set_order(const std::vector<int64_t> &order);
|
|
||||||
std::vector<int64_t> get_order() const;
|
|
||||||
};
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
||||||
|
|
||||||
#endif // MINDSPORE_CORE_OPS_PERMUTE_H_
|
|
|
@ -16,7 +16,6 @@
|
||||||
#include "ops/random_standard_normal.h"
|
#include "ops/random_standard_normal.h"
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
|
|
||||||
|
@ -27,9 +26,9 @@ void RandomStandardNormal::Init(const int64_t seed, const int64_t seed2) {
|
||||||
this->set_seed2(seed2);
|
this->set_seed2(seed2);
|
||||||
}
|
}
|
||||||
|
|
||||||
void RandomStandardNormal::set_seed(const int64_t seed) { this->AddAttr(kSeed, MakeValue(seed)); }
|
void RandomStandardNormal::set_seed(int64_t seed) { this->AddAttr(kSeed, MakeValue(seed)); }
|
||||||
|
|
||||||
void RandomStandardNormal::set_seed2(const int64_t seed2) { this->AddAttr(kSeed2, MakeValue(seed2)); }
|
void RandomStandardNormal::set_seed2(int64_t seed2) { this->AddAttr(kSeed2, MakeValue(seed2)); }
|
||||||
|
|
||||||
int64_t RandomStandardNormal::get_seed() const {
|
int64_t RandomStandardNormal::get_seed() const {
|
||||||
auto value_ptr = GetAttr(kSeed);
|
auto value_ptr = GetAttr(kSeed);
|
||||||
|
|
|
@ -32,17 +32,13 @@ class RandomStandardNormal : public PrimitiveC {
|
||||||
RandomStandardNormal() : PrimitiveC(kNameRandomStandardNormal) {}
|
RandomStandardNormal() : PrimitiveC(kNameRandomStandardNormal) {}
|
||||||
~RandomStandardNormal() = default;
|
~RandomStandardNormal() = default;
|
||||||
MS_DECLARE_PARENT(RandomStandardNormal, PrimitiveC);
|
MS_DECLARE_PARENT(RandomStandardNormal, PrimitiveC);
|
||||||
void Init(const int64_t seed, const int64_t seed2);
|
void Init(int64_t seed, int64_t seed2);
|
||||||
|
|
||||||
void set_seed(const int64_t seed);
|
|
||||||
void set_seed2(const int64_t seed2);
|
|
||||||
|
|
||||||
|
void set_seed(int64_t seed);
|
||||||
|
void set_seed2(int64_t seed2);
|
||||||
int64_t get_seed() const;
|
int64_t get_seed() const;
|
||||||
int64_t get_seed2() const;
|
int64_t get_seed2() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
AbstractBasePtr RandomStandardNormalInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
|
||||||
using PrimRandomStandardNormalPtr = std::shared_ptr<RandomStandardNormal>;
|
using PrimRandomStandardNormalPtr = std::shared_ptr<RandomStandardNormal>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,25 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "ops/return.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
#include "abstract/primitive_infer_map.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
REGISTER_PRIMITIVE_C(kNameReturn, Return);
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,40 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#ifndef MINDSPORE_CORE_OPS_RETURN_H_
|
|
||||||
#define MINDSPORE_CORE_OPS_RETURN_H_
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
#include "ops/primitive_c.h"
|
|
||||||
#include "abstract/abstract_value.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
constexpr auto kNameReturn = "Return";
|
|
||||||
class Return : public PrimitiveC {
|
|
||||||
public:
|
|
||||||
Return() : PrimitiveC(kNameReturn) {}
|
|
||||||
~Return() = default;
|
|
||||||
MS_DECLARE_PARENT(Return, PrimitiveC);
|
|
||||||
void Init() {}
|
|
||||||
};
|
|
||||||
|
|
||||||
using PrimReturnPtr = std::shared_ptr<Return>;
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
||||||
|
|
||||||
#endif // MINDSPORE_CORE_OPS_RETURN_H_
|
|
|
@ -14,12 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <set>
|
|
||||||
#include <vector>
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "ops/size.h"
|
#include "ops/size.h"
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
#include "ops/op_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
|
|
@ -30,11 +30,8 @@ class Size : public PrimitiveC {
|
||||||
Size() : PrimitiveC(kNameSize) {}
|
Size() : PrimitiveC(kNameSize) {}
|
||||||
~Size() = default;
|
~Size() = default;
|
||||||
MS_DECLARE_PARENT(Size, PrimitiveC);
|
MS_DECLARE_PARENT(Size, PrimitiveC);
|
||||||
void Init() {}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
AbstractBasePtr SizeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
|
||||||
using PrimSizePtr = std::shared_ptr<Size>;
|
using PrimSizePtr = std::shared_ptr<Size>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,23 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "ops/tuple_get_item.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
REGISTER_PRIMITIVE_C(kNameTupleGetItem, TupleGetItem);
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,40 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_CORE_OPS_TUPLE_GET_ITEM_H_
|
|
||||||
#define MINDSPORE_CORE_OPS_TUPLE_GET_ITEM_H_
|
|
||||||
#include <map>
|
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
|
||||||
#include <memory>
|
|
||||||
#include "ops/primitive_c.h"
|
|
||||||
#include "abstract/abstract_value.h"
|
|
||||||
#include "utils/check_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
constexpr auto kNameTupleGetItem = "TupleGetItem";
|
|
||||||
class TupleGetItem : public PrimitiveC {
|
|
||||||
public:
|
|
||||||
TupleGetItem() : PrimitiveC(kNameTupleGetItem) {}
|
|
||||||
~TupleGetItem() = default;
|
|
||||||
MS_DECLARE_PARENT(TupleGetItem, PrimitiveC);
|
|
||||||
void Init() {}
|
|
||||||
};
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
||||||
|
|
||||||
#endif // MINDSPORE_CORE_OPS_TUPLE_GET_ITEM_H_
|
|
|
@ -13,23 +13,23 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "ops/uniform_real.h"
|
#include "ops/uniform_real.h"
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
void UniformReal::Init(const int64_t seed, const int64_t seed2) {
|
void UniformReal::Init(int64_t seed, int64_t seed2) {
|
||||||
this->set_seed(seed);
|
this->set_seed(seed);
|
||||||
this->set_seed2(seed2);
|
this->set_seed2(seed2);
|
||||||
}
|
}
|
||||||
|
|
||||||
void UniformReal::set_seed(const int64_t seed) { this->AddAttr(kSeed, MakeValue(seed)); }
|
void UniformReal::set_seed(int64_t seed) { this->AddAttr(kSeed, MakeValue(seed)); }
|
||||||
|
|
||||||
void UniformReal::set_seed2(const int64_t seed2) { this->AddAttr(kSeed2, MakeValue(seed2)); }
|
void UniformReal::set_seed2(int64_t seed2) { this->AddAttr(kSeed2, MakeValue(seed2)); }
|
||||||
|
|
||||||
int64_t UniformReal::get_seed() const {
|
int64_t UniformReal::get_seed() const {
|
||||||
auto value_ptr = GetAttr(kSeed);
|
auto value_ptr = GetAttr(kSeed);
|
||||||
|
|
|
@ -32,11 +32,10 @@ class UniformReal : public PrimitiveC {
|
||||||
UniformReal() : PrimitiveC(kNameUniformReal) {}
|
UniformReal() : PrimitiveC(kNameUniformReal) {}
|
||||||
~UniformReal() = default;
|
~UniformReal() = default;
|
||||||
MS_DECLARE_PARENT(UniformReal, PrimitiveC);
|
MS_DECLARE_PARENT(UniformReal, PrimitiveC);
|
||||||
void Init(const int64_t seed, const int64_t seed2);
|
void Init(int64_t seed, int64_t seed2);
|
||||||
|
|
||||||
void set_seed(const int64_t seed);
|
|
||||||
void set_seed2(const int64_t seed2);
|
|
||||||
|
|
||||||
|
void set_seed(int64_t seed);
|
||||||
|
void set_seed2(int64_t seed2);
|
||||||
int64_t get_seed() const;
|
int64_t get_seed() const;
|
||||||
int64_t get_seed2() const;
|
int64_t get_seed2() const;
|
||||||
};
|
};
|
||||||
|
|
|
@ -39,7 +39,6 @@
|
||||||
#include "ops/batch_to_space_nd.h"
|
#include "ops/batch_to_space_nd.h"
|
||||||
#include "ops/bias_add.h"
|
#include "ops/bias_add.h"
|
||||||
#include "ops/binary_cross_entropy.h"
|
#include "ops/binary_cross_entropy.h"
|
||||||
#include "ops/black_box.h"
|
|
||||||
#include "ops/broadcast_to.h"
|
#include "ops/broadcast_to.h"
|
||||||
#include "ops/broadcast.h"
|
#include "ops/broadcast.h"
|
||||||
#include "ops/cast.h"
|
#include "ops/cast.h"
|
||||||
|
@ -50,7 +49,6 @@
|
||||||
#include "ops/custom_predict.h"
|
#include "ops/custom_predict.h"
|
||||||
#include "ops/custom_extract_features.h"
|
#include "ops/custom_extract_features.h"
|
||||||
#include "ops/concat.h"
|
#include "ops/concat.h"
|
||||||
#include "ops/constant.h"
|
|
||||||
#include "ops/constant_of_shape.h"
|
#include "ops/constant_of_shape.h"
|
||||||
#include "ops/control_depend.h"
|
#include "ops/control_depend.h"
|
||||||
#include "ops/cos.h"
|
#include "ops/cos.h"
|
||||||
|
@ -92,14 +90,11 @@
|
||||||
#include "ops/logical_not.h"
|
#include "ops/logical_not.h"
|
||||||
#include "ops/logical_or.h"
|
#include "ops/logical_or.h"
|
||||||
#include "ops/logical_xor.h"
|
#include "ops/logical_xor.h"
|
||||||
#include "ops/loop.h"
|
|
||||||
#include "ops/lp_normalization.h"
|
#include "ops/lp_normalization.h"
|
||||||
#include "ops/lrn.h"
|
#include "ops/lrn.h"
|
||||||
#include "ops/lsh_projection.h"
|
#include "ops/lsh_projection.h"
|
||||||
#include "ops/lstm.h"
|
#include "ops/lstm.h"
|
||||||
#include "ops/make_tuple.h"
|
|
||||||
#include "ops/mat_mul.h"
|
#include "ops/mat_mul.h"
|
||||||
#include "ops/matrix_diag.h"
|
|
||||||
#include "ops/max_pool.h"
|
#include "ops/max_pool.h"
|
||||||
#include "ops/maximum.h"
|
#include "ops/maximum.h"
|
||||||
#include "ops/merge.h"
|
#include "ops/merge.h"
|
||||||
|
@ -108,13 +103,11 @@
|
||||||
#include "ops/mod.h"
|
#include "ops/mod.h"
|
||||||
#include "ops/mul.h"
|
#include "ops/mul.h"
|
||||||
#include "ops/neg.h"
|
#include "ops/neg.h"
|
||||||
#include "ops/net_output.h"
|
|
||||||
#include "ops/non_max_suppression.h"
|
#include "ops/non_max_suppression.h"
|
||||||
#include "ops/not_equal.h"
|
#include "ops/not_equal.h"
|
||||||
#include "ops/one_hot.h"
|
#include "ops/one_hot.h"
|
||||||
#include "ops/ones_like.h"
|
#include "ops/ones_like.h"
|
||||||
#include "ops/pad.h"
|
#include "ops/pad.h"
|
||||||
#include "ops/permute.h"
|
|
||||||
#include "ops/prelu.h"
|
#include "ops/prelu.h"
|
||||||
#include "ops/prior_box.h"
|
#include "ops/prior_box.h"
|
||||||
#include "ops/proposal.h"
|
#include "ops/proposal.h"
|
||||||
|
@ -127,7 +120,6 @@
|
||||||
#include "ops/relu6.h"
|
#include "ops/relu6.h"
|
||||||
#include "ops/reshape.h"
|
#include "ops/reshape.h"
|
||||||
#include "ops/resize.h"
|
#include "ops/resize.h"
|
||||||
#include "ops/return.h"
|
|
||||||
#include "ops/reverse_sequence.h"
|
#include "ops/reverse_sequence.h"
|
||||||
#include "ops/reverse_v2.h"
|
#include "ops/reverse_v2.h"
|
||||||
#include "ops/rfft.h"
|
#include "ops/rfft.h"
|
||||||
|
@ -169,7 +161,6 @@
|
||||||
#include "ops/tensor_list_stack.h"
|
#include "ops/tensor_list_stack.h"
|
||||||
#include "ops/tile.h"
|
#include "ops/tile.h"
|
||||||
#include "ops/transpose.h"
|
#include "ops/transpose.h"
|
||||||
#include "ops/tuple_get_item.h"
|
|
||||||
#include "ops/unique.h"
|
#include "ops/unique.h"
|
||||||
#include "ops/unstack.h"
|
#include "ops/unstack.h"
|
||||||
#include "ops/unsqueeze.h"
|
#include "ops/unsqueeze.h"
|
||||||
|
@ -291,7 +282,6 @@ FUNC_MSOP2SCHEMAOP_DECLARE(Ceil);
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(Clip);
|
FUNC_MSOP2SCHEMAOP_DECLARE(Clip);
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(Concat);
|
FUNC_MSOP2SCHEMAOP_DECLARE(Concat);
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(ControlDepend);
|
FUNC_MSOP2SCHEMAOP_DECLARE(ControlDepend);
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(Constant);
|
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(ConstantOfShape);
|
FUNC_MSOP2SCHEMAOP_DECLARE(ConstantOfShape);
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(Conv2DBackpropFilterFusion);
|
FUNC_MSOP2SCHEMAOP_DECLARE(Conv2DBackpropFilterFusion);
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(Conv2DBackpropInputFusion);
|
FUNC_MSOP2SCHEMAOP_DECLARE(Conv2DBackpropInputFusion);
|
||||||
|
@ -350,7 +340,6 @@ FUNC_MSOP2SCHEMAOP_DECLARE(LRN);
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(LshProjection);
|
FUNC_MSOP2SCHEMAOP_DECLARE(LshProjection);
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(LSTM);
|
FUNC_MSOP2SCHEMAOP_DECLARE(LSTM);
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(L2NormalizeFusion);
|
FUNC_MSOP2SCHEMAOP_DECLARE(L2NormalizeFusion);
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(MakeTuple);
|
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(MatMul);
|
FUNC_MSOP2SCHEMAOP_DECLARE(MatMul);
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(Maximum);
|
FUNC_MSOP2SCHEMAOP_DECLARE(Maximum);
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(MaximumGrad);
|
FUNC_MSOP2SCHEMAOP_DECLARE(MaximumGrad);
|
||||||
|
@ -384,7 +373,6 @@ FUNC_MSOP2SCHEMAOP_DECLARE(Reciprocal);
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(ReduceFusion);
|
FUNC_MSOP2SCHEMAOP_DECLARE(ReduceFusion);
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(Reshape);
|
FUNC_MSOP2SCHEMAOP_DECLARE(Reshape);
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(Resize);
|
FUNC_MSOP2SCHEMAOP_DECLARE(Resize);
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(Return);
|
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(ReverseSequence);
|
FUNC_MSOP2SCHEMAOP_DECLARE(ReverseSequence);
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(ReverseV2);
|
FUNC_MSOP2SCHEMAOP_DECLARE(ReverseV2);
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(Rfft);
|
FUNC_MSOP2SCHEMAOP_DECLARE(Rfft);
|
||||||
|
@ -432,7 +420,6 @@ FUNC_MSOP2SCHEMAOP_DECLARE(TensorListStack);
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(TileFusion);
|
FUNC_MSOP2SCHEMAOP_DECLARE(TileFusion);
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(TopKFusion);
|
FUNC_MSOP2SCHEMAOP_DECLARE(TopKFusion);
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(Transpose);
|
FUNC_MSOP2SCHEMAOP_DECLARE(Transpose);
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(TupleGetItem);
|
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(Unique);
|
FUNC_MSOP2SCHEMAOP_DECLARE(Unique);
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(UnsortedSegmentSum);
|
FUNC_MSOP2SCHEMAOP_DECLARE(UnsortedSegmentSum);
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(Unsqueeze);
|
FUNC_MSOP2SCHEMAOP_DECLARE(Unsqueeze);
|
||||||
|
|
|
@ -1,61 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include "src/ops/populate/populate_register.h"
|
|
||||||
#include "nnacl/conv_parameter.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace lite {
|
|
||||||
/*
|
|
||||||
OpParameter *PopulateDeconvDwParameter(const mindspore::lite::PrimitiveC *primitive) {
|
|
||||||
ConvParameter *conv_param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
|
|
||||||
if (conv_param == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "malloc ConvParameter failed.";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
memset(conv_param, 0, sizeof(ConvParameter));
|
|
||||||
conv_param->op_parameter_.type_ = primitive->Type();
|
|
||||||
auto conv_primitive =
|
|
||||||
reinterpret_cast<mindspore::lite::DeDepthwiseConv2D *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
|
||||||
conv_param->kernel_h_ = conv_primitive->GetKernelH();
|
|
||||||
conv_param->kernel_w_ = conv_primitive->GetKernelW();
|
|
||||||
conv_param->stride_h_ = conv_primitive->GetStrideH();
|
|
||||||
conv_param->stride_w_ = conv_primitive->GetStrideW();
|
|
||||||
|
|
||||||
auto deconvdw_lite_primitive = (mindspore::lite::DeDepthwiseConv2D *)primitive;
|
|
||||||
conv_param->pad_u_ = deconvdw_lite_primitive->PadUp();
|
|
||||||
conv_param->pad_d_ = deconvdw_lite_primitive->PadDown();
|
|
||||||
conv_param->pad_l_ = deconvdw_lite_primitive->PadLeft();
|
|
||||||
conv_param->pad_r_ = deconvdw_lite_primitive->PadRight();
|
|
||||||
conv_param->dilation_h_ = conv_primitive->GetDilateH();
|
|
||||||
conv_param->dilation_w_ = conv_primitive->GetDilateW();
|
|
||||||
auto act_type = conv_primitive->GetActivationType();
|
|
||||||
switch (act_type) {
|
|
||||||
case schema::ActivationType_RELU:
|
|
||||||
conv_param->act_type_ = ActType_Relu;
|
|
||||||
break;
|
|
||||||
case schema::ActivationType_RELU6:
|
|
||||||
conv_param->act_type_ = ActType_Relu6;
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
conv_param->act_type_ = ActType_No;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
return reinterpret_cast<OpParameter *>(conv_param);
|
|
||||||
}
|
|
||||||
|
|
||||||
*/
|
|
||||||
} // namespace lite
|
|
||||||
} // namespace mindspore
|
|
|
@ -29,9 +29,8 @@
|
||||||
#include "mindspore/core/ops/op_utils.h"
|
#include "mindspore/core/ops/op_utils.h"
|
||||||
#include "ops/fusion/partial_fusion.h"
|
#include "ops/fusion/partial_fusion.h"
|
||||||
#include "ops/depend.h"
|
#include "ops/depend.h"
|
||||||
#include "ops/make_tuple.h"
|
#include "tools/converter/ops/ops_def.h"
|
||||||
#include "ops/quant_dtype_cast.h"
|
#include "ops/quant_dtype_cast.h"
|
||||||
#include "ops/tuple_get_item.h"
|
|
||||||
#include "tools/converter/quant_param_holder.h"
|
#include "tools/converter/quant_param_holder.h"
|
||||||
#include "tools/optimizer/common/gllo_utils.h"
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
#include "tools/converter/quantizer/bitpacking.h"
|
#include "tools/converter/quantizer/bitpacking.h"
|
||||||
|
@ -313,8 +312,8 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
|
||||||
}
|
}
|
||||||
|
|
||||||
RemoveIfDepend(cnode);
|
RemoveIfDepend(cnode);
|
||||||
if (prim->name() == mindspore::ops::kNameDepend || prim->name() == mindspore::ops::kNameTupleGetItem ||
|
if (prim->name() == mindspore::ops::kNameDepend || prim->name() == mindspore::lite::kNameTupleGetItem ||
|
||||||
prim->name() == mindspore::ops::kNameMakeTuple) {
|
prim->name() == mindspore::lite::kNameMakeTuple) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (prim->name() == "make_tuple") {
|
if (prim->name() == "make_tuple") {
|
||||||
|
@ -329,7 +328,7 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimReturn)) {
|
if (opt::CheckPrimitiveType(cnode, prim::kPrimReturn)) {
|
||||||
node->name = mindspore::ops::kNameReturn;
|
node->name = mindspore::lite::kNameReturn;
|
||||||
ret = SetGraphoutputIndex(cnode, subgraph_index, meta_graphT, node.get());
|
ret = SetGraphoutputIndex(cnode, subgraph_index, meta_graphT, node.get());
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "SetOpOutputN failed";
|
MS_LOG(ERROR) << "SetOpOutputN failed";
|
||||||
|
|
|
@ -42,6 +42,10 @@ ADD_CONVERTER_ONLY_OP(TensorArrayScatterV3);
|
||||||
ADD_CONVERTER_ONLY_OP(TensorArraySizeV3);
|
ADD_CONVERTER_ONLY_OP(TensorArraySizeV3);
|
||||||
ADD_CONVERTER_ONLY_OP(TensorArrayV3);
|
ADD_CONVERTER_ONLY_OP(TensorArrayV3);
|
||||||
ADD_CONVERTER_ONLY_OP(TensorArrayWriteV3);
|
ADD_CONVERTER_ONLY_OP(TensorArrayWriteV3);
|
||||||
|
ADD_CONVERTER_ONLY_OP(Constant);
|
||||||
|
ADD_CONVERTER_ONLY_OP(MakeTuple);
|
||||||
|
ADD_CONVERTER_ONLY_OP(TupleGetItem);
|
||||||
|
ADD_CONVERTER_ONLY_OP(Return);
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -23,9 +23,7 @@
|
||||||
#include "tools/common/graph_util.h"
|
#include "tools/common/graph_util.h"
|
||||||
#include "tools/common/protobuf_utils.h"
|
#include "tools/common/protobuf_utils.h"
|
||||||
#include "tools/common/tensor_util.h"
|
#include "tools/common/tensor_util.h"
|
||||||
#include "ops/return.h"
|
#include "tools/converter/ops/ops_def.h"
|
||||||
#include "ops/make_tuple.h"
|
|
||||||
#include "ops/tuple_get_item.h"
|
|
||||||
#include "ir/func_graph.h"
|
#include "ir/func_graph.h"
|
||||||
#include "tools/converter/converter_flags.h"
|
#include "tools/converter/converter_flags.h"
|
||||||
|
|
||||||
|
@ -264,7 +262,7 @@ STATUS CaffeModelParser::ConvertGraphOutputs() {
|
||||||
caffeInspector.InspectModel(caffe_model_);
|
caffeInspector.InspectModel(caffe_model_);
|
||||||
if (caffeInspector.GetGraphOutput().size() > 1) {
|
if (caffeInspector.GetGraphOutput().size() > 1) {
|
||||||
std::vector<AnfNodePtr> make_tuple_inputs;
|
std::vector<AnfNodePtr> make_tuple_inputs;
|
||||||
auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
|
auto make_tuple_prim_ptr = std::make_shared<lite::MakeTuple>();
|
||||||
if (make_tuple_prim_ptr == nullptr) {
|
if (make_tuple_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new MakeTuple failed";
|
MS_LOG(ERROR) << "new MakeTuple failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -283,7 +281,7 @@ STATUS CaffeModelParser::ConvertGraphOutputs() {
|
||||||
make_tuple_cnode->set_fullname_with_scope("return tuple");
|
make_tuple_cnode->set_fullname_with_scope("return tuple");
|
||||||
|
|
||||||
std::vector<AnfNodePtr> op_inputs;
|
std::vector<AnfNodePtr> op_inputs;
|
||||||
auto return_prim_ptr = std::make_shared<ops::Return>();
|
auto return_prim_ptr = std::make_shared<lite::Return>();
|
||||||
if (return_prim_ptr == nullptr) {
|
if (return_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new Return failed";
|
MS_LOG(ERROR) << "new Return failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -295,7 +293,7 @@ STATUS CaffeModelParser::ConvertGraphOutputs() {
|
||||||
cnode->set_fullname_with_scope("Return");
|
cnode->set_fullname_with_scope("Return");
|
||||||
res_graph_->set_return(cnode);
|
res_graph_->set_return(cnode);
|
||||||
} else {
|
} else {
|
||||||
auto returnPrim = std::make_shared<ops::Return>();
|
auto returnPrim = std::make_shared<lite::Return>();
|
||||||
if (returnPrim == nullptr) {
|
if (returnPrim == nullptr) {
|
||||||
MS_LOG(ERROR) << "new Return failed";
|
MS_LOG(ERROR) << "new Return failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -435,7 +433,7 @@ STATUS CaffeModelParser::ConvertTop(const caffe::LayerParameter &layer, const CN
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
abstract_list.emplace_back(abstract);
|
abstract_list.emplace_back(abstract);
|
||||||
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
|
auto tuple_get_item_prim_ptr = std::make_shared<lite::TupleGetItem>();
|
||||||
if (tuple_get_item_prim_ptr == nullptr) {
|
if (tuple_get_item_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new TupleGetItem failed";
|
MS_LOG(ERROR) << "new TupleGetItem failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "tools/converter/parser/onnx/onnx_model_parser.h"
|
#include "tools/converter/parser/onnx/onnx_model_parser.h"
|
||||||
#include "ops/constant.h"
|
#include "tools/converter/ops/ops_def.h"
|
||||||
#include "tools/common/tensor_util.h"
|
#include "tools/common/tensor_util.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -46,7 +46,7 @@ STATUS OnnxConstantParser::AddDataInfoAttr(const onnx::TensorProto &onnx_const_t
|
||||||
}
|
}
|
||||||
|
|
||||||
ops::PrimitiveC *OnnxConstantParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
|
ops::PrimitiveC *OnnxConstantParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
|
||||||
auto prim = std::make_unique<ops::Constant>();
|
auto prim = std::make_unique<lite::Constant>();
|
||||||
|
|
||||||
for (const auto &attr : onnx_node.attribute()) {
|
for (const auto &attr : onnx_node.attribute()) {
|
||||||
if (attr.name() == "sparse_value") {
|
if (attr.name() == "sparse_value") {
|
||||||
|
|
|
@ -20,7 +20,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include "tools/common/tensor_util.h"
|
#include "tools/common/tensor_util.h"
|
||||||
#include "ops/constant.h"
|
#include "tools/converter/ops/ops_def.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
|
@ -62,7 +62,7 @@ STATUS OnnxGivenTensorFillParser::ParseInt8GivenTensorFill(const onnx::NodeProto
|
||||||
}
|
}
|
||||||
ops::PrimitiveC *OnnxGivenTensorFillParser::Parse(const onnx::GraphProto &onnx_graph,
|
ops::PrimitiveC *OnnxGivenTensorFillParser::Parse(const onnx::GraphProto &onnx_graph,
|
||||||
const onnx::NodeProto &onnx_node) {
|
const onnx::NodeProto &onnx_node) {
|
||||||
auto prim = std::make_unique<ops::Constant>();
|
auto prim = std::make_unique<lite::Constant>();
|
||||||
|
|
||||||
std::vector<int64_t> shape_vector;
|
std::vector<int64_t> shape_vector;
|
||||||
auto iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(),
|
auto iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(),
|
||||||
|
|
|
@ -25,10 +25,8 @@
|
||||||
#include "tools/common/graph_util.h"
|
#include "tools/common/graph_util.h"
|
||||||
#include "tools/common/protobuf_utils.h"
|
#include "tools/common/protobuf_utils.h"
|
||||||
#include "tools/common/tensor_util.h"
|
#include "tools/common/tensor_util.h"
|
||||||
#include "ops/return.h"
|
#include "tools/converter/ops/ops_def.h"
|
||||||
#include "ops/make_tuple.h"
|
|
||||||
#include "ops/tensor_list_stack.h"
|
#include "ops/tensor_list_stack.h"
|
||||||
#include "ops/tuple_get_item.h"
|
|
||||||
#include "ir/func_graph.h"
|
#include "ir/func_graph.h"
|
||||||
#include "tools/converter/converter_flags.h"
|
#include "tools/converter/converter_flags.h"
|
||||||
|
|
||||||
|
@ -342,7 +340,7 @@ STATUS OnnxModelParser::ConvertGraphOutputs(const onnx::GraphProto &onnx_graph,
|
||||||
std::vector<AnfNodePtr> return_inputs;
|
std::vector<AnfNodePtr> return_inputs;
|
||||||
if (onnx_graph.output_size() > 1) {
|
if (onnx_graph.output_size() > 1) {
|
||||||
std::vector<AnfNodePtr> make_tuple_inputs;
|
std::vector<AnfNodePtr> make_tuple_inputs;
|
||||||
auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
|
auto make_tuple_prim_ptr = std::make_shared<lite::MakeTuple>();
|
||||||
if (make_tuple_prim_ptr == nullptr) {
|
if (make_tuple_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new MakeTuple failed";
|
MS_LOG(ERROR) << "new MakeTuple failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -391,7 +389,7 @@ STATUS OnnxModelParser::BuildReturnNode(const FuncGraphPtr &anf_graph, const std
|
||||||
MS_LOG(ERROR) << "parameter has null.";
|
MS_LOG(ERROR) << "parameter has null.";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
auto returnPrim = std::make_shared<ops::Return>();
|
auto returnPrim = std::make_shared<lite::Return>();
|
||||||
if (returnPrim == nullptr) {
|
if (returnPrim == nullptr) {
|
||||||
MS_LOG(ERROR) << "new Return failed";
|
MS_LOG(ERROR) << "new Return failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -510,7 +508,7 @@ STATUS OnnxModelParser::BuildOpOutputs(const onnx::NodeProto &onnx_node, const F
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
abstract_list.emplace_back(abstract_tensor);
|
abstract_list.emplace_back(abstract_tensor);
|
||||||
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
|
auto tuple_get_item_prim_ptr = std::make_shared<lite::TupleGetItem>();
|
||||||
if (tuple_get_item_prim_ptr == nullptr) {
|
if (tuple_get_item_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new TupleGetItem failed";
|
MS_LOG(ERROR) << "new TupleGetItem failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
|
|
@ -24,9 +24,7 @@
|
||||||
#include "tools/common/protobuf_utils.h"
|
#include "tools/common/protobuf_utils.h"
|
||||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||||
#include "tools/optimizer/common/gllo_utils.h"
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
#include "ops/return.h"
|
#include "tools/converter/ops/ops_def.h"
|
||||||
#include "ops/make_tuple.h"
|
|
||||||
#include "ops/tuple_get_item.h"
|
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
#include "abstract/utils.h"
|
#include "abstract/utils.h"
|
||||||
#include "tools/converter/converter_flags.h"
|
#include "tools/converter/converter_flags.h"
|
||||||
|
@ -831,7 +829,7 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
abstract_list.emplace_back(abstract_tensor);
|
abstract_list.emplace_back(abstract_tensor);
|
||||||
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
|
auto tuple_get_item_prim_ptr = std::make_shared<lite::TupleGetItem>();
|
||||||
if (tuple_get_item_prim_ptr == nullptr) {
|
if (tuple_get_item_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new TupleGetItem failed";
|
MS_LOG(ERROR) << "new TupleGetItem failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -1038,7 +1036,7 @@ STATUS TFModelParser::MakeAnfGraphOutputs(std::vector<AnfNodePtr> *output_nodes,
|
||||||
}
|
}
|
||||||
if (output_nodes->size() > 1) {
|
if (output_nodes->size() > 1) {
|
||||||
std::vector<AnfNodePtr> *make_tuple_inputs = output_nodes;
|
std::vector<AnfNodePtr> *make_tuple_inputs = output_nodes;
|
||||||
auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
|
auto make_tuple_prim_ptr = std::make_shared<lite::MakeTuple>();
|
||||||
if (make_tuple_prim_ptr == nullptr) {
|
if (make_tuple_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new MakeTuple failed";
|
MS_LOG(ERROR) << "new MakeTuple failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -1048,7 +1046,7 @@ STATUS TFModelParser::MakeAnfGraphOutputs(std::vector<AnfNodePtr> *output_nodes,
|
||||||
auto make_tuple_cnode = anf_graph->NewCNode(*make_tuple_inputs);
|
auto make_tuple_cnode = anf_graph->NewCNode(*make_tuple_inputs);
|
||||||
make_tuple_cnode->set_fullname_with_scope("return tuple");
|
make_tuple_cnode->set_fullname_with_scope("return tuple");
|
||||||
|
|
||||||
auto return_prim_ptr = std::make_shared<ops::Return>();
|
auto return_prim_ptr = std::make_shared<lite::Return>();
|
||||||
if (return_prim_ptr == nullptr) {
|
if (return_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new Return failed";
|
MS_LOG(ERROR) << "new Return failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -1059,7 +1057,7 @@ STATUS TFModelParser::MakeAnfGraphOutputs(std::vector<AnfNodePtr> *output_nodes,
|
||||||
cnode->set_fullname_with_scope("Return");
|
cnode->set_fullname_with_scope("Return");
|
||||||
anf_graph->set_return(cnode);
|
anf_graph->set_return(cnode);
|
||||||
} else {
|
} else {
|
||||||
auto return_prim_ptr = std::make_shared<ops::Return>();
|
auto return_prim_ptr = std::make_shared<lite::Return>();
|
||||||
if (return_prim_ptr == nullptr) {
|
if (return_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new Return failed";
|
MS_LOG(ERROR) << "new Return failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
|
|
@ -21,9 +21,7 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include "tools/converter/converter_flags.h"
|
#include "tools/converter/converter_flags.h"
|
||||||
#include "src/common/file_utils.h"
|
#include "src/common/file_utils.h"
|
||||||
#include "ops/return.h"
|
#include "tools/converter/ops/ops_def.h"
|
||||||
#include "ops/make_tuple.h"
|
|
||||||
#include "ops/tuple_get_item.h"
|
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/primitive_c.h"
|
||||||
#include "ir/func_graph.h"
|
#include "ir/func_graph.h"
|
||||||
|
|
||||||
|
@ -305,7 +303,7 @@ STATUS TfliteModelParser::ConvertGraphOutputs() {
|
||||||
const auto &tflite_subgraph = tflite_model_->subgraphs.front();
|
const auto &tflite_subgraph = tflite_model_->subgraphs.front();
|
||||||
if (tflite_subgraph->outputs.size() > 1) {
|
if (tflite_subgraph->outputs.size() > 1) {
|
||||||
std::vector<AnfNodePtr> make_tuple_inputs;
|
std::vector<AnfNodePtr> make_tuple_inputs;
|
||||||
auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
|
auto make_tuple_prim_ptr = std::make_shared<lite::MakeTuple>();
|
||||||
if (make_tuple_prim_ptr == nullptr) {
|
if (make_tuple_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new MakeTuple failed";
|
MS_LOG(ERROR) << "new MakeTuple failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -325,7 +323,7 @@ STATUS TfliteModelParser::ConvertGraphOutputs() {
|
||||||
make_tuple_cnode->set_fullname_with_scope("return tuple");
|
make_tuple_cnode->set_fullname_with_scope("return tuple");
|
||||||
|
|
||||||
std::vector<AnfNodePtr> op_inputs;
|
std::vector<AnfNodePtr> op_inputs;
|
||||||
auto return_prim_ptr = std::make_shared<ops::Return>();
|
auto return_prim_ptr = std::make_shared<lite::Return>();
|
||||||
if (return_prim_ptr == nullptr) {
|
if (return_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new Return failed";
|
MS_LOG(ERROR) << "new Return failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -337,7 +335,7 @@ STATUS TfliteModelParser::ConvertGraphOutputs() {
|
||||||
cnode->set_fullname_with_scope("Return");
|
cnode->set_fullname_with_scope("Return");
|
||||||
res_graph_->set_return(cnode);
|
res_graph_->set_return(cnode);
|
||||||
} else {
|
} else {
|
||||||
auto returnPrim = std::make_shared<ops::Return>();
|
auto returnPrim = std::make_shared<lite::Return>();
|
||||||
if (returnPrim == nullptr) {
|
if (returnPrim == nullptr) {
|
||||||
MS_LOG(ERROR) << "new Return failed";
|
MS_LOG(ERROR) << "new Return failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -463,7 +461,7 @@ STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
abstract_list.emplace_back(abstract_tensor);
|
abstract_list.emplace_back(abstract_tensor);
|
||||||
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
|
auto tuple_get_item_prim_ptr = std::make_shared<lite::TupleGetItem>();
|
||||||
if (tuple_get_item_prim_ptr == nullptr) {
|
if (tuple_get_item_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new TupleGetItem failed";
|
MS_LOG(ERROR) << "new TupleGetItem failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
|
|
@ -34,7 +34,7 @@
|
||||||
#include "ops/fusion/full_connection.h"
|
#include "ops/fusion/full_connection.h"
|
||||||
#include "ops/fusion/layer_norm_fusion.h"
|
#include "ops/fusion/layer_norm_fusion.h"
|
||||||
#include "ops/gather.h"
|
#include "ops/gather.h"
|
||||||
#include "ops/tuple_get_item.h"
|
#include "tools/converter/ops/ops_def.h"
|
||||||
#include "src/tensor.h"
|
#include "src/tensor.h"
|
||||||
#include "tools/anf_exporter/anf_exporter.h"
|
#include "tools/anf_exporter/anf_exporter.h"
|
||||||
#include "tools/converter/quantizer/quant_cast.h"
|
#include "tools/converter/quantizer/quant_cast.h"
|
||||||
|
@ -414,7 +414,7 @@ STATUS Calibrator::ComputeThreshold() {
|
||||||
for (const auto &output_diverg_info : outputs_diverg_info.second) {
|
for (const auto &output_diverg_info : outputs_diverg_info.second) {
|
||||||
auto output_diverg_cnode = output_diverg_info->cnode;
|
auto output_diverg_cnode = output_diverg_info->cnode;
|
||||||
if (output_diverg_cnode == input_cnode) {
|
if (output_diverg_cnode == input_cnode) {
|
||||||
if (NodePrimitiveType(input_cnode) != ops::kNameTupleGetItem) {
|
if (NodePrimitiveType(input_cnode) != lite::kNameTupleGetItem) {
|
||||||
*(input_infos[i]) = *output_diverg_info;
|
*(input_infos[i]) = *output_diverg_info;
|
||||||
input_infos[i]->cnode = cnode;
|
input_infos[i]->cnode = cnode;
|
||||||
already_computed = true;
|
already_computed = true;
|
||||||
|
@ -801,7 +801,7 @@ STATUS PostTrainingQuantizer::QuantNode() {
|
||||||
|
|
||||||
auto op_type = primitive->name();
|
auto op_type = primitive->name();
|
||||||
MS_LOG(DEBUG) << "OpName: " << op_name;
|
MS_LOG(DEBUG) << "OpName: " << op_name;
|
||||||
if (op_type == ops::kNameTupleGetItem) {
|
if (op_type == lite::kNameTupleGetItem) {
|
||||||
auto index_node = cnode->input(2);
|
auto index_node = cnode->input(2);
|
||||||
auto index_value_node = std::dynamic_pointer_cast<mindspore::ValueNode>(index_node);
|
auto index_value_node = std::dynamic_pointer_cast<mindspore::ValueNode>(index_node);
|
||||||
if (index_value_node == nullptr) {
|
if (index_value_node == nullptr) {
|
||||||
|
|
|
@ -41,7 +41,7 @@
|
||||||
#include "ops/reshape.h"
|
#include "ops/reshape.h"
|
||||||
#include "ops/split.h"
|
#include "ops/split.h"
|
||||||
#include "ops/transpose.h"
|
#include "ops/transpose.h"
|
||||||
#include "ops/tuple_get_item.h"
|
#include "tools/converter/ops/ops_def.h"
|
||||||
#include "tools/anf_exporter/anf_exporter.h"
|
#include "tools/anf_exporter/anf_exporter.h"
|
||||||
#include "tools/converter/quantizer/bitpacking.h"
|
#include "tools/converter/quantizer/bitpacking.h"
|
||||||
#include "src/common/utils.h"
|
#include "src/common/utils.h"
|
||||||
|
@ -113,7 +113,7 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
|
||||||
ops::kNameCrop, ops::kNameEltwise, ops::kNameFullConnection,
|
ops::kNameCrop, ops::kNameEltwise, ops::kNameFullConnection,
|
||||||
ops::kNameGather, ops::kNameLayerNormFusion, ops::kNameMatMul,
|
ops::kNameGather, ops::kNameLayerNormFusion, ops::kNameMatMul,
|
||||||
ops::kNameMaxPoolFusion, ops::kNameMulFusion, ops::kNameReshape,
|
ops::kNameMaxPoolFusion, ops::kNameMulFusion, ops::kNameReshape,
|
||||||
ops::kNameSplit, ops::kNameTranspose, ops::kNameTupleGetItem,
|
ops::kNameSplit, ops::kNameTranspose, lite::kNameTupleGetItem,
|
||||||
};
|
};
|
||||||
bool contain = IsContain(int8OpList, type);
|
bool contain = IsContain(int8OpList, type);
|
||||||
if (!contain) {
|
if (!contain) {
|
||||||
|
|
|
@ -23,7 +23,7 @@
|
||||||
#include "Eigen/Core"
|
#include "Eigen/Core"
|
||||||
#include "ops/fusion/conv2d_fusion.h"
|
#include "ops/fusion/conv2d_fusion.h"
|
||||||
#include "ops/transpose.h"
|
#include "ops/transpose.h"
|
||||||
#include "ops/tuple_get_item.h"
|
#include "tools/converter/ops/ops_def.h"
|
||||||
#include "src/common/common.h"
|
#include "src/common/common.h"
|
||||||
#include "tools/common/tensor_util.h"
|
#include "tools/common/tensor_util.h"
|
||||||
#include "frontend/operator/ops.h"
|
#include "frontend/operator/ops.h"
|
||||||
|
@ -1421,7 +1421,7 @@ CNodePtr GenTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &inpu
|
||||||
|
|
||||||
CNodePtr GenTupleGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &input, size_t index) {
|
CNodePtr GenTupleGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &input, size_t index) {
|
||||||
MS_ASSERT(func_graph != nullptr && input != nullptr);
|
MS_ASSERT(func_graph != nullptr && input != nullptr);
|
||||||
auto tuple_get_item_prim = std::make_shared<ops::TupleGetItem>();
|
auto tuple_get_item_prim = std::make_shared<lite::TupleGetItem>();
|
||||||
auto second_input = NewValueNode(MakeValue<int>(index));
|
auto second_input = NewValueNode(MakeValue<int>(index));
|
||||||
auto tuple_cnode = func_graph->NewCNode(tuple_get_item_prim, {input, second_input});
|
auto tuple_cnode = func_graph->NewCNode(tuple_get_item_prim, {input, second_input});
|
||||||
tuple_cnode->set_fullname_with_scope(input->fullname_with_scope() + "_getitem_" + std::to_string(index));
|
tuple_cnode->set_fullname_with_scope(input->fullname_with_scope() + "_getitem_" + std::to_string(index));
|
||||||
|
|
|
@ -19,7 +19,7 @@
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include "ops/lstm.h"
|
#include "ops/lstm.h"
|
||||||
#include "ops/squeeze.h"
|
#include "ops/squeeze.h"
|
||||||
#include "ops/tuple_get_item.h"
|
#include "tools/converter/ops/ops_def.h"
|
||||||
#include "src/common/utils.h"
|
#include "src/common/utils.h"
|
||||||
#include "tools/common/tensor_util.h"
|
#include "tools/common/tensor_util.h"
|
||||||
#include "utils/utils.h"
|
#include "utils/utils.h"
|
||||||
|
@ -495,7 +495,7 @@ CNodePtr TfliteLstmCellFusion::CreateOutputGetItem(const FuncGraphPtr &func_grap
|
||||||
MS_ASSERT(func_graph != nullptr);
|
MS_ASSERT(func_graph != nullptr);
|
||||||
MS_ASSERT(node != nullptr);
|
MS_ASSERT(node != nullptr);
|
||||||
MS_ASSERT(get_items != nullptr);
|
MS_ASSERT(get_items != nullptr);
|
||||||
auto tuple_get_item_prim = std::make_shared<ops::TupleGetItem>();
|
auto tuple_get_item_prim = std::make_shared<lite::TupleGetItem>();
|
||||||
auto get_item_value = NewValueNode(MakeValue<int>(item_index));
|
auto get_item_value = NewValueNode(MakeValue<int>(item_index));
|
||||||
if (tuple_get_item_prim == nullptr || get_item_value == nullptr) {
|
if (tuple_get_item_prim == nullptr || get_item_value == nullptr) {
|
||||||
MS_LOG(ERROR) << "NewValueNode is nullptr";
|
MS_LOG(ERROR) << "NewValueNode is nullptr";
|
||||||
|
|
|
@ -20,9 +20,7 @@
|
||||||
#include <deque>
|
#include <deque>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
#include "ops/make_tuple.h"
|
|
||||||
#include "tools/converter/ops/ops_def.h"
|
#include "tools/converter/ops/ops_def.h"
|
||||||
#include "ops/return.h"
|
|
||||||
|
|
||||||
namespace mindspore::opt {
|
namespace mindspore::opt {
|
||||||
|
|
||||||
|
@ -147,7 +145,7 @@ FuncGraphPtr FunctionalizeCond::CreateBranchGraph(const AnfNodePtr &node, std::s
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!CheckPrimitiveType(node, prim::kPrimSwitch)) { // graph is not empty
|
if (!CheckPrimitiveType(node, prim::kPrimSwitch)) { // graph is not empty
|
||||||
auto return_prim_ptr = std::make_shared<ops::Return>();
|
auto return_prim_ptr = std::make_shared<lite::Return>();
|
||||||
if (return_prim_ptr == nullptr) {
|
if (return_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
|
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
|
@ -19,9 +19,7 @@
|
||||||
#include <deque>
|
#include <deque>
|
||||||
#include "tools/optimizer/graph/functionalize_while.h"
|
#include "tools/optimizer/graph/functionalize_while.h"
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
#include "ops/make_tuple.h"
|
#include "tools/converter/ops/ops_def.h"
|
||||||
#include "ops/return.h"
|
|
||||||
#include "ops/tuple_get_item.h"
|
|
||||||
#include "tools/converter/ops/while.h"
|
#include "tools/converter/ops/while.h"
|
||||||
#include "tools/common/tensor_util.h"
|
#include "tools/common/tensor_util.h"
|
||||||
|
|
||||||
|
@ -215,7 +213,7 @@ STATUS FunctionalizeWhile::UpdateExitNodeUser() {
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
abstract_list.emplace_back(abstract);
|
abstract_list.emplace_back(abstract);
|
||||||
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
|
auto tuple_get_item_prim_ptr = std::make_shared<lite::TupleGetItem>();
|
||||||
if (tuple_get_item_prim_ptr == nullptr) {
|
if (tuple_get_item_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr";
|
MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -346,7 +344,7 @@ STATUS FunctionalizeWhile::IdentifyCondSubgraphInput() {
|
||||||
}
|
}
|
||||||
|
|
||||||
STATUS FunctionalizeWhile::IdentifyCondSubgraphOutput() {
|
STATUS FunctionalizeWhile::IdentifyCondSubgraphOutput() {
|
||||||
auto return_prim_ptr = std::make_shared<ops::Return>();
|
auto return_prim_ptr = std::make_shared<lite::Return>();
|
||||||
if (return_prim_ptr == nullptr) {
|
if (return_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
|
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -491,7 +489,7 @@ STATUS FunctionalizeWhile::IdentifyBodySubgraphOutput() {
|
||||||
"_cnode");
|
"_cnode");
|
||||||
}
|
}
|
||||||
|
|
||||||
auto return_prim_ptr = std::make_shared<ops::Return>();
|
auto return_prim_ptr = std::make_shared<lite::Return>();
|
||||||
if (return_prim_ptr == nullptr) {
|
if (return_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
|
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -506,7 +504,7 @@ STATUS FunctionalizeWhile::IdentifyBodySubgraphOutput() {
|
||||||
return_cnode->add_input(tmp_output[0]);
|
return_cnode->add_input(tmp_output[0]);
|
||||||
} else {
|
} else {
|
||||||
std::vector<AnfNodePtr> make_tuple_inputs = tmp_output;
|
std::vector<AnfNodePtr> make_tuple_inputs = tmp_output;
|
||||||
auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
|
auto make_tuple_prim_ptr = std::make_shared<lite::MakeTuple>();
|
||||||
if (make_tuple_prim_ptr == nullptr) {
|
if (make_tuple_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr";
|
MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
|
|
@ -21,7 +21,6 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "ops/batch_norm.h"
|
#include "ops/batch_norm.h"
|
||||||
#include "ops/elu.h"
|
#include "ops/elu.h"
|
||||||
#include "ops/depthwise_conv2d.h"
|
|
||||||
#include "ops/fused_batch_norm.h"
|
#include "ops/fused_batch_norm.h"
|
||||||
#include "ops/fusion/activation.h"
|
#include "ops/fusion/activation.h"
|
||||||
#include "ops/fusion/add_fusion.h"
|
#include "ops/fusion/add_fusion.h"
|
||||||
|
@ -85,7 +84,6 @@ using mindspore::ops::kNameConv2D;
|
||||||
using mindspore::ops::kNameConv2DBackpropFilter;
|
using mindspore::ops::kNameConv2DBackpropFilter;
|
||||||
using mindspore::ops::kNameConv2DBackpropInput;
|
using mindspore::ops::kNameConv2DBackpropInput;
|
||||||
using mindspore::ops::kNameConv2dTranspose;
|
using mindspore::ops::kNameConv2dTranspose;
|
||||||
using mindspore::ops::kNameDepthWiseConv2D;
|
|
||||||
using mindspore::ops::kNameDiv;
|
using mindspore::ops::kNameDiv;
|
||||||
using mindspore::ops::kNameElu;
|
using mindspore::ops::kNameElu;
|
||||||
using mindspore::ops::kNameExp;
|
using mindspore::ops::kNameExp;
|
||||||
|
@ -571,7 +569,6 @@ REGIST_PRIMITIVE_ADJUST(kNameBatchNorm, MoveAttrMapCommon<ops::FusedBatchNorm>)
|
||||||
REGIST_PRIMITIVE_ADJUST(kNameConv2DBackpropFilter, MoveAttrMapCommon<ops::Conv2DBackpropFilterFusion>)
|
REGIST_PRIMITIVE_ADJUST(kNameConv2DBackpropFilter, MoveAttrMapCommon<ops::Conv2DBackpropFilterFusion>)
|
||||||
REGIST_PRIMITIVE_ADJUST(kNameConv2DBackpropInput, MoveAttrMapCommon<ops::Conv2DBackpropInputFusion>)
|
REGIST_PRIMITIVE_ADJUST(kNameConv2DBackpropInput, MoveAttrMapCommon<ops::Conv2DBackpropInputFusion>)
|
||||||
REGIST_PRIMITIVE_ADJUST(kNameConv2D, MoveAttrMapConv2D)
|
REGIST_PRIMITIVE_ADJUST(kNameConv2D, MoveAttrMapConv2D)
|
||||||
REGIST_PRIMITIVE_ADJUST(kNameDepthWiseConv2D, MoveAttrMapConv2D)
|
|
||||||
REGIST_PRIMITIVE_ADJUST(kNameConv2dTranspose, MoveAttrMapCommon<ops::Conv2dTransposeFusion>)
|
REGIST_PRIMITIVE_ADJUST(kNameConv2dTranspose, MoveAttrMapCommon<ops::Conv2dTransposeFusion>)
|
||||||
REGIST_PRIMITIVE_ADJUST(kNameDiv, MoveAttrMapCommon<ops::DivFusion>)
|
REGIST_PRIMITIVE_ADJUST(kNameDiv, MoveAttrMapCommon<ops::DivFusion>)
|
||||||
REGIST_PRIMITIVE_ADJUST(kNameElu, MoveAttrMapActivation)
|
REGIST_PRIMITIVE_ADJUST(kNameElu, MoveAttrMapActivation)
|
||||||
|
|
|
@ -18,8 +18,8 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
|
#include "tools/converter/ops/ops_def.h"
|
||||||
#include "ops/depend.h"
|
#include "ops/depend.h"
|
||||||
#include "ops/make_tuple.h"
|
|
||||||
|
|
||||||
namespace mindspore::opt {
|
namespace mindspore::opt {
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -95,7 +95,7 @@ int ProcessInputHaveDependency(const FuncGraphPtr &func_graph, const CNodePtr &c
|
||||||
if (ProcessDependencyWithTwoNodes(func_graph, cnode, false) == lite::RET_OK) {
|
if (ProcessDependencyWithTwoNodes(func_graph, cnode, false) == lite::RET_OK) {
|
||||||
return lite::RET_OK;
|
return lite::RET_OK;
|
||||||
}
|
}
|
||||||
auto make_tuple_prim = NewValueNode(std::make_shared<ops::MakeTuple>());
|
auto make_tuple_prim = NewValueNode(std::make_shared<lite::MakeTuple>());
|
||||||
auto manager = func_graph->manager();
|
auto manager = func_graph->manager();
|
||||||
MS_ASSERT(manager != nullptr);
|
MS_ASSERT(manager != nullptr);
|
||||||
manager->Replace(cnode->input(0), make_tuple_prim);
|
manager->Replace(cnode->input(0), make_tuple_prim);
|
||||||
|
|
|
@ -1,73 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include <vector>
|
|
||||||
#include <memory>
|
|
||||||
#include "common/common_test.h"
|
|
||||||
#include "ops/batch_norm_fold.h"
|
|
||||||
#include "ir/dtype/type.h"
|
|
||||||
#include "ir/value.h"
|
|
||||||
#include "abstract/dshape.h"
|
|
||||||
#include "abstract/abstract_value.h"
|
|
||||||
#include "utils/tensor_construct_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
|
|
||||||
class TestBatchNormFold : public UT::Common {
|
|
||||||
public:
|
|
||||||
TestBatchNormFold() {}
|
|
||||||
void SetUp() {}
|
|
||||||
void TearDown() {}
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(TestBatchNormFold, test_ops_batch_norm_fold1) {
|
|
||||||
auto batch_norm_fold = std::make_shared<BatchNormFold>();
|
|
||||||
batch_norm_fold->Init(0.9, 1e-5, true, 0);
|
|
||||||
EXPECT_EQ((int64_t)(batch_norm_fold->get_momentum() - 0.9), 0);
|
|
||||||
EXPECT_EQ((int64_t)(batch_norm_fold->get_epsilon() - 1e-05), 0);
|
|
||||||
EXPECT_EQ(batch_norm_fold->get_is_training(), true);
|
|
||||||
EXPECT_EQ(batch_norm_fold->get_freeze_bn(), 0);
|
|
||||||
auto input_x = TensorConstructUtils::CreateOnesTensor(kNumberTypeFloat32, std::vector<int64_t>{2, 3});
|
|
||||||
auto mean = TensorConstructUtils::CreateOnesTensor(kNumberTypeFloat32, std::vector<int64_t>{3});
|
|
||||||
auto variance = TensorConstructUtils::CreateOnesTensor(kNumberTypeFloat32, std::vector<int64_t>{3});
|
|
||||||
auto global_step = TensorConstructUtils::CreateOnesTensor(kNumberTypeInt32, std::vector<int64_t>{1});
|
|
||||||
auto abstract = batch_norm_fold->Infer(
|
|
||||||
{input_x->ToAbstract(), mean->ToAbstract(), variance->ToAbstract(), global_step->ToAbstract()});
|
|
||||||
MS_EXCEPTION_IF_NULL(abstract);
|
|
||||||
EXPECT_EQ(abstract->isa<abstract::AbstractTuple>(), true);
|
|
||||||
auto shape_ptr = abstract->BuildShape();
|
|
||||||
MS_EXCEPTION_IF_NULL(shape_ptr);
|
|
||||||
EXPECT_EQ(shape_ptr->isa<abstract::TupleShape>(), true);
|
|
||||||
auto shape = shape_ptr->cast<abstract::TupleShapePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(shape);
|
|
||||||
auto shape_vec = shape->shape();
|
|
||||||
EXPECT_EQ(shape_vec.size(), 4);
|
|
||||||
auto shape1 = shape_vec[0]->cast<abstract::ShapePtr>()->shape();
|
|
||||||
EXPECT_EQ(shape1.size(), 1);
|
|
||||||
EXPECT_EQ(shape1[0], 3);
|
|
||||||
auto type_ptr = abstract->BuildType();
|
|
||||||
MS_EXCEPTION_IF_NULL(type_ptr);
|
|
||||||
auto type = type_ptr->cast<TuplePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(type);
|
|
||||||
auto type_vec = type->elements();
|
|
||||||
MS_EXCEPTION_IF_NULL(type_vec[0]);
|
|
||||||
auto data_type = type_vec[0]->cast<TensorTypePtr>()->element();
|
|
||||||
MS_EXCEPTION_IF_NULL(data_type);
|
|
||||||
EXPECT_EQ(data_type->type_id(), kNumberTypeFloat32);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,62 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include <vector>
|
|
||||||
#include <memory>
|
|
||||||
#include "common/common_test.h"
|
|
||||||
#include "ops/constant.h"
|
|
||||||
#include "ir/dtype/type.h"
|
|
||||||
#include "ir/value.h"
|
|
||||||
#include "abstract/dshape.h"
|
|
||||||
#include "utils/tensor_construct_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
class TestConstant : public UT::Common {
|
|
||||||
public:
|
|
||||||
TestConstant() {}
|
|
||||||
void SetUp() {}
|
|
||||||
void TearDown() {}
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(TestConstant, test_ops_constant1) {
|
|
||||||
auto constant = std::make_shared<Constant>();
|
|
||||||
auto tensor_x = TensorConstructUtils::CreateOnesTensor(kNumberTypeFloat32, std::vector<int64_t>{2, 3, 4, 5});
|
|
||||||
MS_EXCEPTION_IF_NULL(tensor_x);
|
|
||||||
auto abstract = constant->Infer({tensor_x->ToAbstract()});
|
|
||||||
MS_EXCEPTION_IF_NULL(abstract);
|
|
||||||
EXPECT_EQ(abstract->isa<abstract::AbstractTensor>(), true);
|
|
||||||
auto shape_ptr = abstract->BuildShape();
|
|
||||||
MS_EXCEPTION_IF_NULL(shape_ptr);
|
|
||||||
EXPECT_EQ(shape_ptr->isa<abstract::Shape>(), true);
|
|
||||||
auto shape = shape_ptr->cast<abstract::ShapePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(shape);
|
|
||||||
auto shape_vec = shape->shape();
|
|
||||||
auto type = abstract->BuildType();
|
|
||||||
MS_EXCEPTION_IF_NULL(type);
|
|
||||||
EXPECT_EQ(type->isa<TensorType>(), true);
|
|
||||||
auto tensor_type = type->cast<TensorTypePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
|
||||||
auto data_type = tensor_type->element();
|
|
||||||
MS_EXCEPTION_IF_NULL(data_type);
|
|
||||||
EXPECT_EQ(data_type->type_id(), kNumberTypeFloat32);
|
|
||||||
EXPECT_EQ(shape_vec.size(), 4);
|
|
||||||
EXPECT_EQ(shape_vec[0], 2);
|
|
||||||
EXPECT_EQ(shape_vec[1], 3);
|
|
||||||
EXPECT_EQ(shape_vec[2], 4);
|
|
||||||
EXPECT_EQ(shape_vec[3], 5);
|
|
||||||
}
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,62 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include <vector>
|
|
||||||
#include <memory>
|
|
||||||
#include "common/common_test.h"
|
|
||||||
#include "ops/local_response_normalization.h"
|
|
||||||
#include "ir/dtype/type.h"
|
|
||||||
#include "ir/value.h"
|
|
||||||
#include "abstract/dshape.h"
|
|
||||||
#include "utils/tensor_construct_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
class TestLocalResponseNormalization : public UT::Common {
|
|
||||||
public:
|
|
||||||
TestLocalResponseNormalization() {}
|
|
||||||
void SetUp() {}
|
|
||||||
void TearDown() {}
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(TestLocalResponseNormalization, test_ops_local_response_norm1) {
|
|
||||||
auto local_response_norm = std::make_shared<LocalResponseNormalization>();
|
|
||||||
auto tensor_x = TensorConstructUtils::CreateOnesTensor(kNumberTypeFloat16, std::vector<int64_t>{2, 3, 4, 5});
|
|
||||||
MS_EXCEPTION_IF_NULL(tensor_x);
|
|
||||||
auto abstract = local_response_norm->Infer({tensor_x->ToAbstract()});
|
|
||||||
MS_EXCEPTION_IF_NULL(abstract);
|
|
||||||
EXPECT_EQ(abstract->isa<abstract::AbstractTensor>(), true);
|
|
||||||
auto shape_ptr = abstract->BuildShape();
|
|
||||||
MS_EXCEPTION_IF_NULL(shape_ptr);
|
|
||||||
EXPECT_EQ(shape_ptr->isa<abstract::Shape>(), true);
|
|
||||||
auto shape = shape_ptr->cast<abstract::ShapePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(shape);
|
|
||||||
auto shape_vec = shape->shape();
|
|
||||||
auto type = abstract->BuildType();
|
|
||||||
MS_EXCEPTION_IF_NULL(type);
|
|
||||||
EXPECT_EQ(type->isa<TensorType>(), true);
|
|
||||||
auto tensor_type = type->cast<TensorTypePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
|
||||||
auto data_type = tensor_type->element();
|
|
||||||
MS_EXCEPTION_IF_NULL(data_type);
|
|
||||||
EXPECT_EQ(data_type->type_id(), kNumberTypeFloat16);
|
|
||||||
EXPECT_EQ(shape_vec.size(), 4);
|
|
||||||
EXPECT_EQ(shape_vec[0], 2);
|
|
||||||
EXPECT_EQ(shape_vec[1], 3);
|
|
||||||
EXPECT_EQ(shape_vec[2], 4);
|
|
||||||
EXPECT_EQ(shape_vec[3], 5);
|
|
||||||
}
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,65 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include <vector>
|
|
||||||
#include <memory>
|
|
||||||
#include "common/common_test.h"
|
|
||||||
#include "ops/matrix_diag.h"
|
|
||||||
#include "ir/dtype/type.h"
|
|
||||||
#include "ir/value.h"
|
|
||||||
#include "abstract/dshape.h"
|
|
||||||
#include "utils/tensor_construct_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace ops {
|
|
||||||
|
|
||||||
class TestMatrixDiag : public UT::Common {
|
|
||||||
public:
|
|
||||||
TestMatrixDiag() {}
|
|
||||||
void SetUp() {}
|
|
||||||
void TearDown() {}
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(TestMatrixDiag, test_ops_matrix_diag1) {
|
|
||||||
auto matrix_diag = std::make_shared<MatrixDiag>();
|
|
||||||
auto input0 = TensorConstructUtils::CreateOnesTensor(kNumberTypeFloat32, std::vector<int64_t>{2});
|
|
||||||
auto input1 = TensorConstructUtils::CreateOnesTensor(kNumberTypeFloat32, std::vector<int64_t>{3, 2, 2});
|
|
||||||
MS_EXCEPTION_IF_NULL(input0);
|
|
||||||
MS_EXCEPTION_IF_NULL(input1);
|
|
||||||
auto abstract = matrix_diag->Infer({input0->ToAbstract(), input1->ToAbstract()});
|
|
||||||
MS_EXCEPTION_IF_NULL(abstract);
|
|
||||||
EXPECT_EQ(abstract->isa<abstract::AbstractTensor>(), true);
|
|
||||||
auto shape_ptr = abstract->BuildShape();
|
|
||||||
MS_EXCEPTION_IF_NULL(shape_ptr);
|
|
||||||
EXPECT_EQ(shape_ptr->isa<abstract::Shape>(), true);
|
|
||||||
auto shape = shape_ptr->cast<abstract::ShapePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(shape);
|
|
||||||
auto shape_vec = shape->shape();
|
|
||||||
EXPECT_EQ(shape_vec.size(), 3);
|
|
||||||
EXPECT_EQ(shape_vec[0], 3);
|
|
||||||
EXPECT_EQ(shape_vec[1], 2);
|
|
||||||
EXPECT_EQ(shape_vec[2], 2);
|
|
||||||
auto type = abstract->BuildType();
|
|
||||||
MS_EXCEPTION_IF_NULL(type);
|
|
||||||
EXPECT_EQ(type->isa<TensorType>(), true);
|
|
||||||
auto tensor_type = type->cast<TensorTypePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
|
||||||
auto data_type = tensor_type->element();
|
|
||||||
MS_EXCEPTION_IF_NULL(data_type);
|
|
||||||
EXPECT_EQ(data_type->type_id(), kNumberTypeFloat32);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace mindspore
|
|
Loading…
Reference in New Issue