!20912 code check for ops

Merge pull request !20912 from Simson/codex
This commit is contained in:
i-robot 2021-07-29 02:21:09 +00:00 committed by Gitee
commit ad589e6780
127 changed files with 386 additions and 162 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -39,6 +39,7 @@ class RecvGpuKernel : public GpuKernel {
return true;
}
bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_node_ = kernel_node;
wait_stream_ = reinterpret_cast<cudaStream_t>(GetAttr<uintptr_t>(kernel_node, "wait_event_stream"));
wait_event_ = reinterpret_cast<cudaEvent_t>(GetAttr<uintptr_t>(kernel_node, "wait_event"));

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -39,6 +39,7 @@ class SendGpuKernel : public GpuKernel {
return true;
}
bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_node_ = kernel_node;
record_stream_ = reinterpret_cast<cudaStream_t>(GetAttr<uintptr_t>(kernel_node, "record_event_stream"));
record_event_ = reinterpret_cast<cudaEvent_t>(GetAttr<uintptr_t>(kernel_node, "record_event"));

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -34,11 +34,17 @@ const std::vector<size_t> &DatasetInitKernel::GetOutputSizeList() const { return
const std::vector<size_t> &DatasetInitKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
bool DatasetInitKernel::Init(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
queue_name_ = GetAttr<std::string>(kernel_node, "queue_name");
std::vector<std::vector<int>> shapes;
std::vector<TypePtr> types;
GetShapeAndType(kernel_node, &shapes, &types);
for (auto item : types) {
MS_EXCEPTION_IF_NULL(item);
}
if (types.size() < shapes.size()) {
MS_LOG(EXCEPTION) << "types size is less than shapes size.";
}
for (size_t i = 0; i < shapes.size(); i++) {
int unit = UnitSizeInBytes(types[i]->type_id());
int nums = ElementNums(shapes[i]);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -47,12 +47,18 @@ const std::vector<size_t> &DatasetIteratorKernel::GetOutputSizeList() const { re
const std::vector<size_t> &DatasetIteratorKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
bool DatasetIteratorKernel::Init(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_node_ = kernel_node;
queue_name_ = GetAttr<std::string>(kernel_node, "shared_name");
std::vector<std::vector<int>> shapes;
std::vector<TypePtr> types;
GetShapeAndType(kernel_node, &shapes, &types);
for (auto item : types) {
MS_EXCEPTION_IF_NULL(item);
}
if (types.size() < shapes.size()) {
MS_LOG(EXCEPTION) << "types size is less than shapes size.";
}
for (size_t i = 0; i < shapes.size(); i++) {
int unit = UnitSizeInBytes(types[i]->type_id());
int nums = ElementNums(shapes[i]);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -48,6 +48,9 @@ void GetNextProfiling::SaveProfilingData() {
return;
}
for (uint32_t index = 0; index < queue_size_.size(); index++) {
if (index > time_stamp_.size() - 1) {
MS_LOG(EXCEPTION) << "index exceeds time_stamp_ size.";
}
handle << Name() << " " << time_stamp_[index].first << " " << time_stamp_[index].second << " " << queue_size_[index]
<< std::endl;
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -68,6 +68,7 @@ int ElementNums(const std::vector<int> &shape) {
}
void GetShapeAndType(const CNodePtr &kernel_node, std::vector<std::vector<int>> *shapes, std::vector<TypePtr> *types) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(shapes);
MS_EXCEPTION_IF_NULL(types);
std::vector<std::vector<int64_t>> shapes_me =

View File

@ -87,6 +87,7 @@ class PrintGpuKernel : public GpuKernel {
}
bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_node_ = kernel_node;
if (AnfAlgo::HasNodeAttr("string_pos", kernel_node)) {
string_value_ = GetAttr<std::vector<std::string>>(kernel_node, "string_value");

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -77,6 +77,7 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel {
}
bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_node_ = kernel_node;
nccl_data_type_ = nccl_dtype(AnfAlgo::GetInputDeviceDataType(kernel_node, 0));
InferCommType(kernel_node);

View File

@ -49,6 +49,7 @@ class AssignGpuKernel : public GpuKernel {
}
bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_node_ = kernel_node;
if (!CheckParam(kernel_node)) {
return false;
@ -71,6 +72,7 @@ class AssignGpuKernel : public GpuKernel {
private:
bool CheckParam(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but AssignGpuKernel needs 2 output.";

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -62,6 +62,7 @@ class BoundingBoxDecodeGpuKernel : public GpuKernel {
}
bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but BoundingBoxDecode needs 2 inputs.";
@ -89,10 +90,11 @@ class BoundingBoxDecodeGpuKernel : public GpuKernel {
InitSizeLists();
const size_t coordinate_size = 4;
if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<ValueTuple>() ||
AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<ValueList>()) {
auto means = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means");
MS_EXCEPTION_IF_NULL(means);
if (means->isa<ValueTuple>() || means->isa<ValueList>()) {
means_ = GetAttr<std::vector<float>>(kernel_node, "means");
} else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<FloatImm>()) {
} else if (means->isa<FloatImm>()) {
float mean = GetAttr<float>(kernel_node, "means");
for (size_t i = 0; i < coordinate_size; i++) {
means_.emplace_back(mean);
@ -101,10 +103,11 @@ class BoundingBoxDecodeGpuKernel : public GpuKernel {
MS_LOG(EXCEPTION) << "Attribute means type is invalid.";
}
if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<ValueTuple>() ||
AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<ValueList>()) {
auto stds = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds");
MS_EXCEPTION_IF_NULL(stds);
if (stds->isa<ValueTuple>() || stds->isa<ValueList>()) {
stds_ = GetAttr<std::vector<float>>(kernel_node, "stds");
} else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<FloatImm>()) {
} else if (stds->isa<FloatImm>()) {
float std = GetAttr<float>(kernel_node, "stds");
for (size_t i = 0; i < coordinate_size; i++) {
stds_.emplace_back(std);

View File

@ -61,6 +61,7 @@ class BoundingBoxEncodeGpuKernel : public GpuKernel {
}
bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but BoundingBoxEncode needs 2 inputs.";
@ -88,10 +89,11 @@ class BoundingBoxEncodeGpuKernel : public GpuKernel {
InitSizeLists();
const size_t coordinate_size = 4;
if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<ValueTuple>() ||
AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<ValueList>()) {
auto means = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means");
MS_EXCEPTION_IF_NULL(means);
if (means->isa<ValueTuple>() || means->isa<ValueList>()) {
means_ = GetAttr<std::vector<float>>(kernel_node, "means");
} else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<FloatImm>()) {
} else if (means->isa<FloatImm>()) {
float mean = GetAttr<float>(kernel_node, "means");
for (size_t i = 0; i < coordinate_size; i++) {
means_.emplace_back(mean);
@ -99,11 +101,11 @@ class BoundingBoxEncodeGpuKernel : public GpuKernel {
} else {
MS_LOG(EXCEPTION) << "Attribute means type is invalid.";
}
if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<ValueTuple>() ||
AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<ValueList>()) {
auto stds = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds");
MS_EXCEPTION_IF_NULL(stds);
if (stds->isa<ValueTuple>() || stds->isa<ValueList>()) {
stds_ = GetAttr<std::vector<float>>(kernel_node, "stds");
} else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<FloatImm>()) {
} else if (stds->isa<FloatImm>()) {
float std = GetAttr<float>(kernel_node, "stds");
for (size_t i = 0; i < coordinate_size; i++) {
stds_.emplace_back(std);

View File

@ -55,6 +55,7 @@ class CheckValidGpuKernel : public GpuKernel {
}
bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but CheckValid needs 2 inputs.";

View File

@ -59,6 +59,7 @@ class GpuConvertToDynamicShapeGpuKernel : public GpuKernel {
}
bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_node_ = kernel_node;
size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_count != 1) {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -59,6 +59,7 @@ class IOUGpuKernel : public GpuKernel {
}
bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but IOU needs 2 inputs.";

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -89,6 +89,7 @@ class RandomCategoricalGpuKernel : public GpuKernel {
}
bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_node_ = kernel_node;
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 3) {

View File

@ -70,6 +70,7 @@ class RandomChoiceWithMaskGpuKernel : public GpuKernel {
}
bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
uint32_t time_interval = std::chrono::system_clock::now().time_since_epoch().count();
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {

View File

@ -76,6 +76,7 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel {
}
bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_node_ = kernel_node;
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
@ -173,6 +174,7 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel {
MS_LOG(EXCEPTION) << "range_max_ failed to cast";
}
range = static_cast<S>(range_max_);
MS_EXCEPTION_IF_ZERO("range", range);
return static_cast<S>(1.0f / range);
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -42,6 +42,7 @@ void AdderFusion::set_activation_type(const ActivationType activation_type) {
ActivationType AdderFusion::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr));
}
REGISTER_PRIMITIVE_C(kNameAdderFusion, AdderFusion);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -31,9 +31,21 @@ void ArgMaxFusion::set_out_max_value(const bool out_max_value) {
}
void ArgMaxFusion::set_top_k(const int64_t top_k) { (void)this->AddAttr(kTopK, MakeValue(top_k)); }
bool ArgMaxFusion::get_keep_dims() const { return GetValue<bool>(GetAttr(kKeepDims)); }
bool ArgMaxFusion::get_out_max_value() const { return GetValue<bool>(GetAttr(kOutMaxValue)); }
int64_t ArgMaxFusion::get_top_k() const { return GetValue<int64_t>(GetAttr(kTopK)); }
bool ArgMaxFusion::get_keep_dims() const {
auto keep_dims = GetAttr(kKeepDims);
MS_EXCEPTION_IF_NULL(keep_dims);
return GetValue<bool>(keep_dims);
}
bool ArgMaxFusion::get_out_max_value() const {
auto out_maxv = GetAttr(kOutMaxValue);
MS_EXCEPTION_IF_NULL(out_maxv);
return GetValue<bool>(out_maxv);
}
int64_t ArgMaxFusion::get_top_k() const {
auto topk = GetAttr(kTopK);
MS_EXCEPTION_IF_NULL(topk);
return GetValue<int64_t>(topk);
}
REGISTER_PRIMITIVE_C(kNameArgMaxFusion, ArgMaxFusion);
} // namespace ops

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -29,14 +29,21 @@ void ArgMinFusion::set_keep_dims(const bool keep_dims) { (void)this->AddAttr(kKe
void ArgMinFusion::set_out_max_value(bool out_max_value) { (void)AddAttr(kOutMaxValue, MakeValue(out_max_value)); }
void ArgMinFusion::set_top_k(int64_t top_k) { (void)this->AddAttr(kTopK, MakeValue(top_k)); }
bool ArgMinFusion::get_keep_dims() const { return GetValue<bool>(GetAttr(kKeepDims)); }
bool ArgMinFusion::get_keep_dims() const {
auto keep_dims = GetAttr(kKeepDims);
MS_EXCEPTION_IF_NULL(keep_dims);
return GetValue<bool>(keep_dims);
}
bool ArgMinFusion::get_out_max_value() const {
auto value_ptr = GetAttr(kOutMaxValue);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<bool>(value_ptr);
}
int64_t ArgMinFusion::get_top_k() const {
auto value_ptr = GetAttr(kTopK);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -40,17 +40,22 @@ void AvgPoolFusion::set_activation_type(ActivationType activation_type) {
bool AvgPoolFusion::get_global() const {
auto value_ptr = GetAttr(kGlobal);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<bool>(value_ptr);
}
ActivationType AvgPoolFusion::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr));
}
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
for (auto item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto op_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
@ -66,6 +71,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
auto in_w = in_shape[3];
auto strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides));
(void)CheckAndConvertUtils::CheckPositiveVector(kStride, strides, op_name);
auto kernel_h = kernel_size[2];
auto kernel_w = kernel_size[3];
auto stride_h = strides[2];
@ -90,8 +96,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr arg) { return arg == nullptr; })) {
MS_LOG(EXCEPTION) << "nullptr";
for (auto item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
return input_args[0]->BuildType();
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -52,11 +52,13 @@ void Conv2DBackpropFilterFusion::set_in_channel(const int64_t in_channel) {
ActivationType Conv2DBackpropFilterFusion::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr));
}
int64_t Conv2DBackpropFilterFusion::get_in_channel() const {
auto value_ptr = GetAttr(kInChannel);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}
REGISTER_PRIMITIVE_C(kNameConv2DBackpropFilterFusion, Conv2DBackpropFilterFusion);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -51,11 +51,13 @@ void Conv2DBackpropInputFusion::set_activation_type(const ActivationType &activa
}
int64_t Conv2DBackpropInputFusion::get_in_channel() const {
auto value_ptr = GetAttr(kInChannel);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}
ActivationType Conv2DBackpropInputFusion::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr));
}
REGISTER_PRIMITIVE_C(kNameConv2DBackpropInputFusion, Conv2DBackpropInputFusion);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -48,14 +48,17 @@ void Conv2DFusion::set_activation_type(const ActivationType &activation_type) {
}
int64_t Conv2DFusion::get_in_channel() const {
auto value_ptr = GetAttr(kInChannel);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}
std::vector<int64_t> Conv2DFusion::get_pad_list() const {
auto value_ptr = GetAttr(kPadList);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr);
}
ActivationType Conv2DFusion::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr));
}
REGISTER_PRIMITIVE_C(kNameConv2DFusion, Conv2DFusion);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -71,11 +71,13 @@ void Conv2dTransposeFusion::set_activation_type(ActivationType activation_type)
std::vector<int64_t> Conv2dTransposeFusion::get_output_paddings() const {
auto value_ptr = GetAttr(kOutputPaddings);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr);
}
ActivationType Conv2dTransposeFusion::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr));
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -29,6 +29,7 @@ void DivFusion::set_activation_type(const ActivationType &activation_type) {
ActivationType DivFusion::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr));
}
REGISTER_PRIMITIVE_C(kNameDivFusion, DivFusion);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -23,6 +23,7 @@ namespace ops {
void EmbeddingLookupFusion::set_max_norm(const float max_norm) { (void)this->AddAttr(kMaxNorm, MakeValue(max_norm)); }
float EmbeddingLookupFusion::get_max_norm() const {
auto value_ptr = GetAttr(kMaxNorm);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<float>(value_ptr);
}
void EmbeddingLookupFusion::Init(const float max_norm) { this->set_max_norm(max_norm); }

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -37,14 +37,17 @@ void ExpFusion::set_shift(const float shift) { (void)this->AddAttr(kShift, MakeV
float ExpFusion::get_base() const {
auto value_ptr = GetAttr(kBase);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<float>(value_ptr);
}
float ExpFusion::get_scale() const {
auto value_ptr = GetAttr(kScale);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<float>(value_ptr);
}
float ExpFusion::get_shift() const {
auto value_ptr = GetAttr(kShift);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<float>(value_ptr);
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -22,13 +22,25 @@ namespace mindspore {
namespace ops {
void FullConnection::set_has_bias(const bool has_bias) { (void)this->AddAttr(kHasBias, MakeValue(has_bias)); }
bool FullConnection::get_has_bias() const { return GetValue<bool>(GetAttr(kHasBias)); }
bool FullConnection::get_has_bias() const {
auto value_ptr = GetAttr(kHasBias);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<bool>(value_ptr);
}
void FullConnection::set_axis(const int64_t axis) { (void)this->AddAttr(kAxis, MakeValue(axis)); }
int64_t FullConnection::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); }
int64_t FullConnection::get_axis() const {
auto value_ptr = GetAttr(kAxis);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}
void FullConnection::set_use_axis(const bool use_axis) { (void)this->AddAttr(kUseAxis, MakeValue(use_axis)); }
bool FullConnection::get_use_axis() const { return GetValue<bool>(GetAttr(kUseAxis)); }
bool FullConnection::get_use_axis() const {
auto value_ptr = GetAttr(kUseAxis);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<bool>(value_ptr);
}
void FullConnection::set_activation_type(const ActivationType &activation_type) {
int64_t swi = activation_type;
@ -36,6 +48,7 @@ void FullConnection::set_activation_type(const ActivationType &activation_type)
}
ActivationType FullConnection::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr));
}
void FullConnection::Init(const bool has_bias, const int64_t axis, const bool use_axis,
@ -57,14 +70,16 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input1->BuildShape())[kShape];
auto prim_axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
auto has_bias = GetValue<bool>(primitive->GetAttr(kHasBias));
const int64_t input_num_bias = 3;
const int64_t input_num = 2;
if (has_bias) {
(void)CheckAndConvertUtils::CheckInteger("input_args.size()", input_args.size(), kEqual, 3, prim_name);
(void)CheckAndConvertUtils::CheckInteger("input_args.size()", input_args.size(), kEqual, input_num_bias, prim_name);
} else {
(void)CheckAndConvertUtils::CheckInteger("input_args.size()", input_args.size(), kEqual, 2, prim_name);
(void)CheckAndConvertUtils::CheckInteger("input_args.size()", input_args.size(), kEqual, input_num, prim_name);
}
auto use_axis = GetValue<bool>(primitive->GetAttr(kUseAxis));
if (use_axis && (prim_axis < 1 || prim_axis > (int64_t)input0_shape.size())) {
MS_EXCEPTION(ValueError) << "Full Connection axis invalid";
MS_EXCEPTION(ValueError) << "Full Connection axis is invalid";
}
int64_t new_k = 1;
if (use_axis) {
@ -72,7 +87,7 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P
new_k *= input0_shape[t];
}
if (new_k != input1_shape[1]) {
MS_EXCEPTION(ValueError) << "Input1 size invalid";
MS_EXCEPTION(ValueError) << "Input1 size is invalid";
}
} else {
new_k = input1_shape[1];
@ -80,7 +95,7 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P
if (has_bias) {
auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
if (input2_shape[0] != input1_shape[0]) {
MS_EXCEPTION(ValueError) << "Bias size invalid";
MS_EXCEPTION(ValueError) << "Bias size is invalid";
}
}
std::vector<int64_t> out_shape = {(int64_t)input0_shape.size()};

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -34,6 +34,7 @@ void L2NormalizeFusion::set_activation_type(const ActivationType &activation_typ
ActivationType L2NormalizeFusion::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr));
}
REGISTER_PRIMITIVE_C(kNameL2NormalizeFusion, L2NormalizeFusion);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -32,6 +32,7 @@ void LayerNormFusion::set_elementwise_affine(const bool elementwise_affine) {
bool LayerNormFusion::get_elementwise_affine() const {
auto value_ptr = GetAttr(kElementwiseAffine);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<bool>(value_ptr);
}
REGISTER_PRIMITIVE_C(kNameLayerNormFusion, LayerNormFusion);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -40,24 +40,28 @@ void MaxPoolFusion::set_activation_type(ActivationType activation_type) {
bool MaxPoolFusion::get_global() const {
auto value_ptr = GetAttr(kGlobal);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<bool>(value_ptr);
}
ActivationType MaxPoolFusion::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr));
}
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
MS_EXCEPTION_IF_NULL(input_args[0]);
auto op_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) {
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
}
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, 4, op_name);
const int64_t in_shape_size = 4;
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, in_shape_size, op_name);
auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode)));
auto batch = in_shape[0];
@ -84,21 +88,22 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
out_shape = {batch, out_h, out_w, channel};
}
if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {
MS_LOG(EXCEPTION) << "Kernel size is not valid.";
MS_LOG(EXCEPTION) << "Kernel size is invalid.";
}
return std::make_shared<abstract::Shape>(out_shape);
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) {
MS_LOG(EXCEPTION) << "nullptr";
}
return input_args[0]->BuildType();
}
} // namespace
AbstractBasePtr MaxPoolFusionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
for (auto item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape());
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -29,6 +29,7 @@ void MulFusion::set_activation_type(const ActivationType &activation_type) {
}
ActivationType MulFusion::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr));
}
void MulFusion::Init(const ActivationType &activation_type) { this->set_activation_type(activation_type); }

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -39,10 +39,12 @@ void PadFusion::set_constant_value(const float constant_value) {
PaddingMode PadFusion::get_padding_mode() const {
auto value_ptr = GetAttr(kPaddingMode);
MS_EXCEPTION_IF_NULL(value_ptr);
return PaddingMode(GetValue<int64_t>(value_ptr));
}
float PadFusion::get_constant_value() const {
auto value_ptr = GetAttr(kConstantValue);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<float>(value_ptr);
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -25,6 +25,7 @@ void PartialFusion::set_sub_graph_index(const int64_t sub_graph_index) {
}
int64_t PartialFusion::get_sub_graph_index() const {
auto value_ptr = GetAttr(kSubGraphIndex);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}
REGISTER_PRIMITIVE_C(kNamePartialFusion, PartialFusion);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -42,6 +42,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
@ -54,6 +55,10 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
AbstractBasePtr PowFusionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 2;
CheckAndConvertUtils::CheckInteger("PowFusion infer", SizeToLong(input_args.size()), kGreaterEqual, input_num,
primitive->name());
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape());
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -36,10 +36,12 @@ void PReLUFusion::set_slope(const std::vector<float> &slope) { (void)this->AddAt
bool PReLUFusion::get_channel_shared() const {
auto value_ptr = GetAttr(kChannelShared);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<bool>(value_ptr);
}
std::vector<float> PReLUFusion::get_slope() const {
auto value_ptr = GetAttr(kSlope);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<float>>(value_ptr);
}
REGISTER_PRIMITIVE_C(kNamePReLUFusion, PReLUFusion);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -41,21 +41,25 @@ void ReduceFusion::set_coeff(const float coeff) { (void)this->AddAttr(kCoeff, Ma
bool ReduceFusion::get_keep_dims() const {
auto value_ptr = GetAttr(kKeepDims);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<bool>(value_ptr);
}
ReduceMode ReduceFusion::get_mode() const {
auto value_ptr = GetAttr(kMode);
MS_EXCEPTION_IF_NULL(value_ptr);
return ReduceMode(GetValue<int64_t>(value_ptr));
}
bool ReduceFusion::get_reduce_to_end() const {
auto value_ptr = GetAttr(kReduceToEnd);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<bool>(value_ptr);
}
float ReduceFusion::get_coeff() const {
auto value_ptr = GetAttr(kCoeff);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<float>(value_ptr);
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -32,6 +32,7 @@ void ScaleFusion::set_activation_type(const ActivationType &activation_type) {
ActivationType ScaleFusion::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr));
}
REGISTER_PRIMITIVE_C(kNameScaleFusion, ScaleFusion);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -26,6 +26,7 @@ void SliceFusion::set_axes(const std::vector<int64_t> &axes) { (void)this->AddAt
std::vector<int64_t> SliceFusion::get_axes() const {
auto value_ptr = GetAttr(kAxes);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr);
}
@ -39,6 +40,8 @@ AbstractBasePtr SliceFusionInfer(const abstract::AnalysisEnginePtr &, const Prim
auto size_v = input_args[2]->BuildValue();
auto x_type = input_args[0]->BuildType();
MS_EXCEPTION_IF_NULL(x_type);
MS_EXCEPTION_IF_NULL(begin_v);
MS_EXCEPTION_IF_NULL(size_v);
auto tensor_type = x_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto data_type = tensor_type->element();

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -29,6 +29,7 @@ void SubFusion::set_activation_type(const ActivationType &activation_type) {
ActivationType SubFusion::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr));
}
REGISTER_PRIMITIVE_C(kNameSubFusion, SubFusion);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -26,6 +26,7 @@ void TileFusion::set_dims(const std::vector<int64_t> &dims) { (void)this->AddAtt
std::vector<int64_t> TileFusion::get_dims() const {
auto value_ptr = GetAttr(kDims);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr);
}
REGISTER_PRIMITIVE_C(kNameTileFusion, TileFusion);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -32,11 +32,13 @@ void TopKFusion::set_largest(const int64_t largest) { (void)this->AddAttr(kLarge
int64_t TopKFusion::get_axis() const {
auto value_ptr = GetAttr(kAxis);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}
int64_t TopKFusion::get_largest() const {
auto value_ptr = GetAttr(kLargest);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}
REGISTER_PRIMITIVE_C(kNameTopKFusion, TopKFusion);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -38,6 +38,7 @@ void ActivationGrad::set_activation_type(const ActivationType &type) {
ActivationType ActivationGrad::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr));
}
@ -45,6 +46,7 @@ void ActivationGrad::set_alpha(const float alpha) { (void)this->AddAttr(kAlpha,
float ActivationGrad::get_alpha() const {
auto value_ptr = GetAttr(kAlpha);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<float>(value_ptr);
}
REGISTER_PRIMITIVE_C(kNameActivationGrad, ActivationGrad);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -31,6 +31,7 @@ void BatchNormGrad::set_epsilon(const float epsilon) { (void)this->AddAttr(kEpsi
float BatchNormGrad::get_epsilon() const {
auto value_ptr = this->GetAttr(kEpsilon);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<float>(value_ptr);
}
@ -38,15 +39,19 @@ void BatchNormGrad::set_is_training(const bool is_training) { this->AddAttr(kIsT
bool BatchNormGrad::get_is_training() const {
auto value_ptr = this->GetAttr(kIsTraining);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<bool>(value_ptr);
}
AbstractBasePtr BatchNormGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
MS_EXCEPTION_IF_NULL(input_args[1]);
MS_EXCEPTION_IF_NULL(input_args[2]);
MS_EXCEPTION_IF_NULL(input_args[3]);
for (auto item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const int64_t input_num = 5;
CheckAndConvertUtils::CheckInteger("BatchNormGrad infer", SizeToLong(input_args.size()), kGreaterEqual, input_num,
primitive->name());
auto y_backprop_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
CheckAndConvertUtils::Check("BatchNorm y_backprop_shape", y_backprop_shape, kEqual, "BatchNorm x_shape", x_shape);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -38,9 +38,6 @@ abstract::ShapePtr BinaryCrossEntroyGradInferShape(const PrimitivePtr &primitive
}
TypePtr BinaryCrossEntroyGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
std::map<std::string, TypePtr> types;
types.emplace("x_shape", input_args[0]->BuildType());
@ -67,6 +64,13 @@ Reduction BinaryCrossEntropyGrad::get_reduction() const {
AbstractBasePtr BinaryCrossEntropyGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const int64_t input_num = 4;
CheckAndConvertUtils::CheckInteger("BinaryCrossEntropyGrad infer", SizeToLong(input_args.size()), kGreaterEqual,
input_num, primitive->name());
return std::make_shared<abstract::AbstractTensor>(BinaryCrossEntroyGradInferType(primitive, input_args),
BinaryCrossEntroyGradInferShape(primitive, input_args)->shape());
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -30,6 +30,7 @@ void BNGrad::set_eps(const float eps) { (void)this->AddAttr(kEps, MakeValue(eps)
float BNGrad::get_eps() const {
auto value_ptr = this->GetAttr(kEps);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<float>(value_ptr);
}
@ -37,6 +38,7 @@ void BNGrad::set_momentum(const float momentum) { (void)this->AddAttr(kMomentum,
float BNGrad::get_momentum() const {
auto value_ptr = this->GetAttr(kMomentum);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<float>(value_ptr);
}
REGISTER_PRIMITIVE_C(kNameBNGrad, BNGrad);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -71,6 +71,7 @@ void Conv2DBackpropFilter::set_out_channel(const int64_t out_channel) {
int64_t Conv2DBackpropFilter::get_out_channel() const {
auto value_ptr = GetAttr(kOutChannel);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}
@ -80,6 +81,7 @@ void Conv2DBackpropFilter::set_kernel_size(const std::vector<int64_t> &kernel_si
std::vector<int64_t> Conv2DBackpropFilter::get_kernel_size() const {
auto value_ptr = GetAttr(kKernelSize);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr);
}
@ -90,6 +92,7 @@ void Conv2DBackpropFilter::set_pad_mode(const PadMode &pad_mode) {
PadMode Conv2DBackpropFilter::get_pad_mode() const {
auto value_ptr = GetAttr(kPadMode);
MS_EXCEPTION_IF_NULL(value_ptr);
return PadMode(GetValue<int64_t>(value_ptr));
}
@ -99,6 +102,7 @@ void Conv2DBackpropFilter::set_pad_list(const std::vector<int64_t> &pad_list) {
std::vector<int64_t> Conv2DBackpropFilter::get_pad_list() const {
auto value_ptr = GetAttr(kPadList);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr);
}
@ -106,6 +110,7 @@ void Conv2DBackpropFilter::set_mode(const int64_t mode) { (void)this->AddAttr(kM
int64_t Conv2DBackpropFilter::get_mode() const {
auto value_ptr = GetAttr(kMode);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}
@ -113,6 +118,7 @@ void Conv2DBackpropFilter::set_stride(const std::vector<int64_t> &stride) { this
std::vector<int64_t> Conv2DBackpropFilter::get_stride() const {
auto value_ptr = GetAttr(kStride);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr);
}
@ -122,6 +128,7 @@ void Conv2DBackpropFilter::set_dilation(const std::vector<int64_t> &dilation) {
std::vector<int64_t> Conv2DBackpropFilter::get_dilation() const {
auto value_ptr = GetAttr(kDilation);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr);
}
@ -129,6 +136,7 @@ void Conv2DBackpropFilter::set_group(const int64_t group) { (void)this->AddAttr(
int64_t Conv2DBackpropFilter::get_group() const {
auto value_ptr = GetAttr(kGroup);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}
@ -139,6 +147,7 @@ void Conv2DBackpropFilter::set_format(const Format &format) {
Format Conv2DBackpropFilter::get_format() const {
auto value_ptr = GetAttr(kFormat);
MS_EXCEPTION_IF_NULL(value_ptr);
return Format(GetValue<int64_t>(value_ptr));
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -172,51 +172,61 @@ void Conv2DBackpropInput::set_pad_list(const std::vector<int64_t> &pad_list) {
int64_t Conv2DBackpropInput::get_out_channel() const {
auto value_ptr = GetAttr(kOutChannel);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}
std::vector<int64_t> Conv2DBackpropInput::get_kernel_size() const {
auto value_ptr = GetAttr(kKernelSize);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr);
}
std::vector<int64_t> Conv2DBackpropInput::get_stride() const {
auto value_ptr = GetAttr(kStride);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr);
}
std::vector<int64_t> Conv2DBackpropInput::get_dilation() const {
auto value_ptr = GetAttr(kDilation);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr);
}
PadMode Conv2DBackpropInput::get_pad_mode() const {
auto value_ptr = GetAttr(kPadMode);
MS_EXCEPTION_IF_NULL(value_ptr);
return PadMode(GetValue<int64_t>(value_ptr));
}
std::vector<int64_t> Conv2DBackpropInput::get_pad() const {
auto value_ptr = GetAttr(kPad);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr);
}
int64_t Conv2DBackpropInput::get_mode() const {
auto value_ptr = GetAttr(kMode);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}
int64_t Conv2DBackpropInput::get_group() const {
auto value_ptr = GetAttr(kGroup);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}
Format Conv2DBackpropInput::get_format() const {
auto value_ptr = GetAttr(kFormat);
MS_EXCEPTION_IF_NULL(value_ptr);
return Format(GetValue<int64_t>(value_ptr));
}
std::vector<int64_t> Conv2DBackpropInput::get_pad_list() const {
auto value_ptr = GetAttr(kPadList);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr);
}
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DBackpropInput, prim::kPrimConv2DBackpropInput, Conv2DBackpropInputInfer, nullptr,

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -45,6 +45,7 @@ void DeConv2DGradFilter::set_in_channel(const int64_t in_channel) {
int64_t DeConv2DGradFilter::get_in_channel() const {
auto value_ptr = GetAttr(kInChannel);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}
@ -54,6 +55,7 @@ void DeConv2DGradFilter::set_out_channel(const int64_t out_channel) {
int64_t DeConv2DGradFilter::get_out_channel() const {
auto value_ptr = GetAttr(kOutChannel);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}
@ -63,6 +65,7 @@ void DeConv2DGradFilter::set_kernel_size(const std::vector<int64_t> &kernel_size
std::vector<int64_t> DeConv2DGradFilter::get_kernel_size() const {
auto value_ptr = GetAttr(kKernelSize);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr);
}
@ -73,6 +76,7 @@ void DeConv2DGradFilter::set_pad_mode(const PadMode &pad_mode) {
PadMode DeConv2DGradFilter::get_pad_mode() const {
auto value_ptr = GetAttr(kPadMode);
MS_EXCEPTION_IF_NULL(value_ptr);
return PadMode(GetValue<int64_t>(value_ptr));
}
@ -82,6 +86,7 @@ void DeConv2DGradFilter::set_pad_list(const std::vector<int64_t> &pad_list) {
std::vector<int64_t> DeConv2DGradFilter::get_pad_list() const {
auto value_ptr = GetAttr(kPadList);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr);
}
@ -91,6 +96,7 @@ void DeConv2DGradFilter::set_stride(const std::vector<int64_t> &stride) {
std::vector<int64_t> DeConv2DGradFilter::get_stride() const {
auto value_ptr = GetAttr(kStride);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr);
}
@ -100,6 +106,7 @@ void DeConv2DGradFilter::set_dilation(const std::vector<int64_t> &dilation) {
std::vector<int64_t> DeConv2DGradFilter::get_dilation() const {
auto value_ptr = GetAttr(kDilation);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr);
}
@ -107,6 +114,7 @@ void DeConv2DGradFilter::set_group(const int64_t group) { (void)this->AddAttr(kG
int64_t DeConv2DGradFilter::get_group() const {
auto value_ptr = GetAttr(kGroup);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}
@ -117,6 +125,7 @@ void DeConv2DGradFilter::set_format(const Format &format) {
Format DeConv2DGradFilter::get_format() const {
auto value_ptr = GetAttr(kFormat);
MS_EXCEPTION_IF_NULL(value_ptr);
return Format(GetValue<int64_t>(value_ptr));
}
@ -127,6 +136,7 @@ void DeConv2DGradFilter::set_activation_type(const ActivationType &activation_ty
ActivationType DeConv2DGradFilter::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr));
}
@ -134,6 +144,7 @@ void DeConv2DGradFilter::set_has_bias(const bool has_bias) { (void)this->AddAttr
bool DeConv2DGradFilter::get_has_bias() const {
auto value_ptr = GetAttr(kHasBias);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<bool>(value_ptr);
}
REGISTER_PRIMITIVE_C(kNameDeConv2DGradFilter, DeConv2DGradFilter);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -28,6 +28,7 @@ void DropoutGrad::set_keep_prob(const float keep_prob) {
float DropoutGrad::get_keep_prob() const {
auto value_ptr = GetAttr(kKeepProb);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<float>(value_ptr);
}
@ -55,6 +56,13 @@ TypePtr DropoutGradInferType(const PrimitivePtr &prim, const std::vector<Abstrac
AbstractBasePtr DropoutGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
for (auto item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const int64_t input_num = 2;
CheckAndConvertUtils::CheckInteger("DropoutGrad infer", SizeToLong(input_args.size()), kGreaterEqual, input_num,
primitive->name());
return std::make_shared<abstract::AbstractTensor>(DropoutGradInferType(primitive, input_args),
DropoutGradInferShape(primitive, input_args)->shape());
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -22,7 +22,8 @@ AbstractBasePtr FlattenGradInfer(const abstract::AnalysisEnginePtr &, const Prim
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, prim_name);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -46,6 +46,7 @@ void GroupConv2DGradInput::set_in_channel(const int64_t &in_channel) {
int64_t GroupConv2DGradInput::get_in_channel() const {
auto value_ptr = GetAttr(kInChannel);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}
@ -55,6 +56,7 @@ void GroupConv2DGradInput::set_out_channel(const int64_t &out_channel) {
int64_t GroupConv2DGradInput::get_out_channel() const {
auto value_ptr = GetAttr(kOutChannel);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}
@ -64,6 +66,7 @@ void GroupConv2DGradInput::set_kernel_size(const std::vector<int64_t> &kernel_si
std::vector<int64_t> GroupConv2DGradInput::get_kernel_size() const {
auto value_ptr = GetAttr(kKernelSize);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr);
}
@ -74,6 +77,7 @@ void GroupConv2DGradInput::set_pad_mode(const PadMode &pad_mode) {
PadMode GroupConv2DGradInput::get_pad_mode() const {
auto value_ptr = GetAttr(kPadMode);
MS_EXCEPTION_IF_NULL(value_ptr);
return PadMode(GetValue<int64_t>(value_ptr));
}
@ -83,6 +87,7 @@ void GroupConv2DGradInput::set_pad_list(const std::vector<int64_t> &pad_list) {
std::vector<int64_t> GroupConv2DGradInput::get_pad_list() const {
auto value_ptr = GetAttr(kPadList);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr);
}
@ -90,6 +95,7 @@ void GroupConv2DGradInput::set_stride(const std::vector<int64_t> &stride) { this
std::vector<int64_t> GroupConv2DGradInput::get_stride() const {
auto value_ptr = GetAttr(kStride);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr);
}
@ -99,6 +105,7 @@ void GroupConv2DGradInput::set_dilation(const std::vector<int64_t> &dilation) {
std::vector<int64_t> GroupConv2DGradInput::get_dilation() const {
auto value_ptr = GetAttr(kDilation);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr);
}
@ -106,6 +113,7 @@ void GroupConv2DGradInput::set_group(const int64_t &group) { (void)this->AddAttr
int64_t GroupConv2DGradInput::get_group() const {
auto value_ptr = GetAttr(kGroup);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}
@ -114,7 +122,9 @@ void GroupConv2DGradInput::set_input_shape(const std::vector<int64_t> &input_sha
}
std::vector<int64_t> GroupConv2DGradInput::get_input_shape() const {
return GetValue<std::vector<int64_t>>(GetAttr(kInputShape));
auto value_ptr = GetAttr(kInputShape);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr);
}
void GroupConv2DGradInput::set_format(const Format &format) {
@ -124,6 +134,7 @@ void GroupConv2DGradInput::set_format(const Format &format) {
Format GroupConv2DGradInput::get_format() const {
auto value_ptr = GetAttr(kFormat);
MS_EXCEPTION_IF_NULL(value_ptr);
return Format(GetValue<int64_t>(value_ptr));
}
@ -134,6 +145,7 @@ void GroupConv2DGradInput::set_activation_type(const ActivationType &activation_
ActivationType GroupConv2DGradInput::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr));
}
@ -141,6 +153,7 @@ void GroupConv2DGradInput::set_has_bias(const bool has_bias) { (void)this->AddAt
bool GroupConv2DGradInput::get_has_bias() const {
auto value_ptr = GetAttr(kHasBias);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<bool>(value_ptr);
}
AbstractBasePtr GroupConv2DGradInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
@ -152,11 +165,16 @@ AbstractBasePtr GroupConv2DGradInputInfer(const abstract::AnalysisEnginePtr &, c
MS_EXCEPTION_IF_NULL(input_args[0]);
// Infer shape
auto shape = GetValue<std::vector<int64_t>>(primitive->GetAttr(kInputShape));
auto shape_ptr = primitive->GetAttr(kInputShape);
MS_EXCEPTION_IF_NULL(shape_ptr);
auto shape = GetValue<std::vector<int64_t>>(shape_ptr);
// Infer type
auto type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
auto type_ptr = input_args[0]->BuildType();
MS_EXCEPTION_IF_NULL(type_ptr);
auto type_tensor_ptr = type_ptr->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(type_tensor_ptr);
auto type = type_tensor_ptr->element();
return std::make_shared<abstract::AbstractTensor>(type, shape);
}
REGISTER_PRIMITIVE_C(kNameGroupConv2DGradInput, GroupConv2DGradInput);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -48,10 +48,12 @@ void LayerNormGrad::set_begin_params_axis(const int64_t begin_params_axis) {
}
int64_t LayerNormGrad::get_begin_norm_axis() const {
auto value_ptr = this->GetAttr(kBeginNormAxis);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}
int64_t LayerNormGrad::get_begin_params_axis() const {
auto value_ptr = this->GetAttr(kBeginParamsAxis);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}
REGISTER_PRIMITIVE_EVAL_IMPL(LayerNormGrad, prim::kPrimLayerNormGrad, LayerNormGradInfer, nullptr, true);

Some files were not shown because too many files have changed in this diff Show More