!20912 code check for ops

Merge pull request !20912 from Simson/codex
This commit is contained in:
i-robot 2021-07-29 02:21:09 +00:00 committed by Gitee
commit ad589e6780
127 changed files with 386 additions and 162 deletions

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -39,6 +39,7 @@ class RecvGpuKernel : public GpuKernel {
return true; return true;
} }
bool Init(const CNodePtr &kernel_node) override { bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_node_ = kernel_node; kernel_node_ = kernel_node;
wait_stream_ = reinterpret_cast<cudaStream_t>(GetAttr<uintptr_t>(kernel_node, "wait_event_stream")); wait_stream_ = reinterpret_cast<cudaStream_t>(GetAttr<uintptr_t>(kernel_node, "wait_event_stream"));
wait_event_ = reinterpret_cast<cudaEvent_t>(GetAttr<uintptr_t>(kernel_node, "wait_event")); wait_event_ = reinterpret_cast<cudaEvent_t>(GetAttr<uintptr_t>(kernel_node, "wait_event"));

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -39,6 +39,7 @@ class SendGpuKernel : public GpuKernel {
return true; return true;
} }
bool Init(const CNodePtr &kernel_node) override { bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_node_ = kernel_node; kernel_node_ = kernel_node;
record_stream_ = reinterpret_cast<cudaStream_t>(GetAttr<uintptr_t>(kernel_node, "record_event_stream")); record_stream_ = reinterpret_cast<cudaStream_t>(GetAttr<uintptr_t>(kernel_node, "record_event_stream"));
record_event_ = reinterpret_cast<cudaEvent_t>(GetAttr<uintptr_t>(kernel_node, "record_event")); record_event_ = reinterpret_cast<cudaEvent_t>(GetAttr<uintptr_t>(kernel_node, "record_event"));

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -34,11 +34,17 @@ const std::vector<size_t> &DatasetInitKernel::GetOutputSizeList() const { return
const std::vector<size_t> &DatasetInitKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } const std::vector<size_t> &DatasetInitKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
bool DatasetInitKernel::Init(const CNodePtr &kernel_node) { bool DatasetInitKernel::Init(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
queue_name_ = GetAttr<std::string>(kernel_node, "queue_name"); queue_name_ = GetAttr<std::string>(kernel_node, "queue_name");
std::vector<std::vector<int>> shapes; std::vector<std::vector<int>> shapes;
std::vector<TypePtr> types; std::vector<TypePtr> types;
GetShapeAndType(kernel_node, &shapes, &types); GetShapeAndType(kernel_node, &shapes, &types);
for (auto item : types) {
MS_EXCEPTION_IF_NULL(item);
}
if (types.size() < shapes.size()) {
MS_LOG(EXCEPTION) << "types size is less than shapes size.";
}
for (size_t i = 0; i < shapes.size(); i++) { for (size_t i = 0; i < shapes.size(); i++) {
int unit = UnitSizeInBytes(types[i]->type_id()); int unit = UnitSizeInBytes(types[i]->type_id());
int nums = ElementNums(shapes[i]); int nums = ElementNums(shapes[i]);

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -47,12 +47,18 @@ const std::vector<size_t> &DatasetIteratorKernel::GetOutputSizeList() const { re
const std::vector<size_t> &DatasetIteratorKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } const std::vector<size_t> &DatasetIteratorKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
bool DatasetIteratorKernel::Init(const CNodePtr &kernel_node) { bool DatasetIteratorKernel::Init(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_node_ = kernel_node; kernel_node_ = kernel_node;
queue_name_ = GetAttr<std::string>(kernel_node, "shared_name"); queue_name_ = GetAttr<std::string>(kernel_node, "shared_name");
std::vector<std::vector<int>> shapes; std::vector<std::vector<int>> shapes;
std::vector<TypePtr> types; std::vector<TypePtr> types;
GetShapeAndType(kernel_node, &shapes, &types); GetShapeAndType(kernel_node, &shapes, &types);
for (auto item : types) {
MS_EXCEPTION_IF_NULL(item);
}
if (types.size() < shapes.size()) {
MS_LOG(EXCEPTION) << "types size is less than shapes size.";
}
for (size_t i = 0; i < shapes.size(); i++) { for (size_t i = 0; i < shapes.size(); i++) {
int unit = UnitSizeInBytes(types[i]->type_id()); int unit = UnitSizeInBytes(types[i]->type_id());
int nums = ElementNums(shapes[i]); int nums = ElementNums(shapes[i]);

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -48,6 +48,9 @@ void GetNextProfiling::SaveProfilingData() {
return; return;
} }
for (uint32_t index = 0; index < queue_size_.size(); index++) { for (uint32_t index = 0; index < queue_size_.size(); index++) {
if (index > time_stamp_.size() - 1) {
MS_LOG(EXCEPTION) << "index exceeds time_stamp_ size.";
}
handle << Name() << " " << time_stamp_[index].first << " " << time_stamp_[index].second << " " << queue_size_[index] handle << Name() << " " << time_stamp_[index].first << " " << time_stamp_[index].second << " " << queue_size_[index]
<< std::endl; << std::endl;
} }

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -68,6 +68,7 @@ int ElementNums(const std::vector<int> &shape) {
} }
void GetShapeAndType(const CNodePtr &kernel_node, std::vector<std::vector<int>> *shapes, std::vector<TypePtr> *types) { void GetShapeAndType(const CNodePtr &kernel_node, std::vector<std::vector<int>> *shapes, std::vector<TypePtr> *types) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(shapes); MS_EXCEPTION_IF_NULL(shapes);
MS_EXCEPTION_IF_NULL(types); MS_EXCEPTION_IF_NULL(types);
std::vector<std::vector<int64_t>> shapes_me = std::vector<std::vector<int64_t>> shapes_me =

View File

@ -87,6 +87,7 @@ class PrintGpuKernel : public GpuKernel {
} }
bool Init(const CNodePtr &kernel_node) override { bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_node_ = kernel_node; kernel_node_ = kernel_node;
if (AnfAlgo::HasNodeAttr("string_pos", kernel_node)) { if (AnfAlgo::HasNodeAttr("string_pos", kernel_node)) {
string_value_ = GetAttr<std::vector<std::string>>(kernel_node, "string_value"); string_value_ = GetAttr<std::vector<std::string>>(kernel_node, "string_value");

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -77,6 +77,7 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel {
} }
bool Init(const CNodePtr &kernel_node) override { bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_node_ = kernel_node; kernel_node_ = kernel_node;
nccl_data_type_ = nccl_dtype(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)); nccl_data_type_ = nccl_dtype(AnfAlgo::GetInputDeviceDataType(kernel_node, 0));
InferCommType(kernel_node); InferCommType(kernel_node);

View File

@ -49,6 +49,7 @@ class AssignGpuKernel : public GpuKernel {
} }
bool Init(const CNodePtr &kernel_node) override { bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_node_ = kernel_node; kernel_node_ = kernel_node;
if (!CheckParam(kernel_node)) { if (!CheckParam(kernel_node)) {
return false; return false;
@ -71,6 +72,7 @@ class AssignGpuKernel : public GpuKernel {
private: private:
bool CheckParam(const CNodePtr &kernel_node) { bool CheckParam(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) { if (input_num != 2) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but AssignGpuKernel needs 2 output."; MS_LOG(ERROR) << "Input number is " << input_num << ", but AssignGpuKernel needs 2 output.";

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -62,6 +62,7 @@ class BoundingBoxDecodeGpuKernel : public GpuKernel {
} }
bool Init(const CNodePtr &kernel_node) override { bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) { if (input_num != 2) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but BoundingBoxDecode needs 2 inputs."; MS_LOG(ERROR) << "Input number is " << input_num << ", but BoundingBoxDecode needs 2 inputs.";
@ -89,10 +90,11 @@ class BoundingBoxDecodeGpuKernel : public GpuKernel {
InitSizeLists(); InitSizeLists();
const size_t coordinate_size = 4; const size_t coordinate_size = 4;
if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<ValueTuple>() || auto means = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means");
AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<ValueList>()) { MS_EXCEPTION_IF_NULL(means);
if (means->isa<ValueTuple>() || means->isa<ValueList>()) {
means_ = GetAttr<std::vector<float>>(kernel_node, "means"); means_ = GetAttr<std::vector<float>>(kernel_node, "means");
} else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<FloatImm>()) { } else if (means->isa<FloatImm>()) {
float mean = GetAttr<float>(kernel_node, "means"); float mean = GetAttr<float>(kernel_node, "means");
for (size_t i = 0; i < coordinate_size; i++) { for (size_t i = 0; i < coordinate_size; i++) {
means_.emplace_back(mean); means_.emplace_back(mean);
@ -101,10 +103,11 @@ class BoundingBoxDecodeGpuKernel : public GpuKernel {
MS_LOG(EXCEPTION) << "Attribute means type is invalid."; MS_LOG(EXCEPTION) << "Attribute means type is invalid.";
} }
if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<ValueTuple>() || auto stds = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds");
AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<ValueList>()) { MS_EXCEPTION_IF_NULL(stds);
if (stds->isa<ValueTuple>() || stds->isa<ValueList>()) {
stds_ = GetAttr<std::vector<float>>(kernel_node, "stds"); stds_ = GetAttr<std::vector<float>>(kernel_node, "stds");
} else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<FloatImm>()) { } else if (stds->isa<FloatImm>()) {
float std = GetAttr<float>(kernel_node, "stds"); float std = GetAttr<float>(kernel_node, "stds");
for (size_t i = 0; i < coordinate_size; i++) { for (size_t i = 0; i < coordinate_size; i++) {
stds_.emplace_back(std); stds_.emplace_back(std);

View File

@ -61,6 +61,7 @@ class BoundingBoxEncodeGpuKernel : public GpuKernel {
} }
bool Init(const CNodePtr &kernel_node) override { bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) { if (input_num != 2) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but BoundingBoxEncode needs 2 inputs."; MS_LOG(ERROR) << "Input number is " << input_num << ", but BoundingBoxEncode needs 2 inputs.";
@ -88,10 +89,11 @@ class BoundingBoxEncodeGpuKernel : public GpuKernel {
InitSizeLists(); InitSizeLists();
const size_t coordinate_size = 4; const size_t coordinate_size = 4;
if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<ValueTuple>() || auto means = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means");
AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<ValueList>()) { MS_EXCEPTION_IF_NULL(means);
if (means->isa<ValueTuple>() || means->isa<ValueList>()) {
means_ = GetAttr<std::vector<float>>(kernel_node, "means"); means_ = GetAttr<std::vector<float>>(kernel_node, "means");
} else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<FloatImm>()) { } else if (means->isa<FloatImm>()) {
float mean = GetAttr<float>(kernel_node, "means"); float mean = GetAttr<float>(kernel_node, "means");
for (size_t i = 0; i < coordinate_size; i++) { for (size_t i = 0; i < coordinate_size; i++) {
means_.emplace_back(mean); means_.emplace_back(mean);
@ -99,11 +101,11 @@ class BoundingBoxEncodeGpuKernel : public GpuKernel {
} else { } else {
MS_LOG(EXCEPTION) << "Attribute means type is invalid."; MS_LOG(EXCEPTION) << "Attribute means type is invalid.";
} }
auto stds = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds");
if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<ValueTuple>() || MS_EXCEPTION_IF_NULL(stds);
AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<ValueList>()) { if (stds->isa<ValueTuple>() || stds->isa<ValueList>()) {
stds_ = GetAttr<std::vector<float>>(kernel_node, "stds"); stds_ = GetAttr<std::vector<float>>(kernel_node, "stds");
} else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<FloatImm>()) { } else if (stds->isa<FloatImm>()) {
float std = GetAttr<float>(kernel_node, "stds"); float std = GetAttr<float>(kernel_node, "stds");
for (size_t i = 0; i < coordinate_size; i++) { for (size_t i = 0; i < coordinate_size; i++) {
stds_.emplace_back(std); stds_.emplace_back(std);

View File

@ -55,6 +55,7 @@ class CheckValidGpuKernel : public GpuKernel {
} }
bool Init(const CNodePtr &kernel_node) override { bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) { if (input_num != 2) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but CheckValid needs 2 inputs."; MS_LOG(ERROR) << "Input number is " << input_num << ", but CheckValid needs 2 inputs.";

View File

@ -59,6 +59,7 @@ class GpuConvertToDynamicShapeGpuKernel : public GpuKernel {
} }
bool Init(const CNodePtr &kernel_node) override { bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_node_ = kernel_node; kernel_node_ = kernel_node;
size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node); size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_count != 1) { if (input_count != 1) {

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -59,6 +59,7 @@ class IOUGpuKernel : public GpuKernel {
} }
bool Init(const CNodePtr &kernel_node) override { bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) { if (input_num != 2) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but IOU needs 2 inputs."; MS_LOG(ERROR) << "Input number is " << input_num << ", but IOU needs 2 inputs.";

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -89,6 +89,7 @@ class RandomCategoricalGpuKernel : public GpuKernel {
} }
bool Init(const CNodePtr &kernel_node) override { bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_node_ = kernel_node; kernel_node_ = kernel_node;
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 3) { if (input_num != 3) {

View File

@ -70,6 +70,7 @@ class RandomChoiceWithMaskGpuKernel : public GpuKernel {
} }
bool Init(const CNodePtr &kernel_node) override { bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
uint32_t time_interval = std::chrono::system_clock::now().time_since_epoch().count(); uint32_t time_interval = std::chrono::system_clock::now().time_since_epoch().count();
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) { if (input_num != 1) {

View File

@ -76,6 +76,7 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel {
} }
bool Init(const CNodePtr &kernel_node) override { bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_node_ = kernel_node; kernel_node_ = kernel_node;
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) { if (input_num != 1) {
@ -173,6 +174,7 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel {
MS_LOG(EXCEPTION) << "range_max_ failed to cast"; MS_LOG(EXCEPTION) << "range_max_ failed to cast";
} }
range = static_cast<S>(range_max_); range = static_cast<S>(range_max_);
MS_EXCEPTION_IF_ZERO("range", range);
return static_cast<S>(1.0f / range); return static_cast<S>(1.0f / range);
} }

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -42,6 +42,7 @@ void AdderFusion::set_activation_type(const ActivationType activation_type) {
ActivationType AdderFusion::get_activation_type() const { ActivationType AdderFusion::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType); auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr)); return ActivationType(GetValue<int64_t>(value_ptr));
} }
REGISTER_PRIMITIVE_C(kNameAdderFusion, AdderFusion); REGISTER_PRIMITIVE_C(kNameAdderFusion, AdderFusion);

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -31,9 +31,21 @@ void ArgMaxFusion::set_out_max_value(const bool out_max_value) {
} }
void ArgMaxFusion::set_top_k(const int64_t top_k) { (void)this->AddAttr(kTopK, MakeValue(top_k)); } void ArgMaxFusion::set_top_k(const int64_t top_k) { (void)this->AddAttr(kTopK, MakeValue(top_k)); }
bool ArgMaxFusion::get_keep_dims() const { return GetValue<bool>(GetAttr(kKeepDims)); } bool ArgMaxFusion::get_keep_dims() const {
bool ArgMaxFusion::get_out_max_value() const { return GetValue<bool>(GetAttr(kOutMaxValue)); } auto keep_dims = GetAttr(kKeepDims);
int64_t ArgMaxFusion::get_top_k() const { return GetValue<int64_t>(GetAttr(kTopK)); } MS_EXCEPTION_IF_NULL(keep_dims);
return GetValue<bool>(keep_dims);
}
bool ArgMaxFusion::get_out_max_value() const {
auto out_maxv = GetAttr(kOutMaxValue);
MS_EXCEPTION_IF_NULL(out_maxv);
return GetValue<bool>(out_maxv);
}
int64_t ArgMaxFusion::get_top_k() const {
auto topk = GetAttr(kTopK);
MS_EXCEPTION_IF_NULL(topk);
return GetValue<int64_t>(topk);
}
REGISTER_PRIMITIVE_C(kNameArgMaxFusion, ArgMaxFusion); REGISTER_PRIMITIVE_C(kNameArgMaxFusion, ArgMaxFusion);
} // namespace ops } // namespace ops

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -29,14 +29,21 @@ void ArgMinFusion::set_keep_dims(const bool keep_dims) { (void)this->AddAttr(kKe
void ArgMinFusion::set_out_max_value(bool out_max_value) { (void)AddAttr(kOutMaxValue, MakeValue(out_max_value)); } void ArgMinFusion::set_out_max_value(bool out_max_value) { (void)AddAttr(kOutMaxValue, MakeValue(out_max_value)); }
void ArgMinFusion::set_top_k(int64_t top_k) { (void)this->AddAttr(kTopK, MakeValue(top_k)); } void ArgMinFusion::set_top_k(int64_t top_k) { (void)this->AddAttr(kTopK, MakeValue(top_k)); }
bool ArgMinFusion::get_keep_dims() const { return GetValue<bool>(GetAttr(kKeepDims)); } bool ArgMinFusion::get_keep_dims() const {
auto keep_dims = GetAttr(kKeepDims);
MS_EXCEPTION_IF_NULL(keep_dims);
return GetValue<bool>(keep_dims);
}
bool ArgMinFusion::get_out_max_value() const { bool ArgMinFusion::get_out_max_value() const {
auto value_ptr = GetAttr(kOutMaxValue); auto value_ptr = GetAttr(kOutMaxValue);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<bool>(value_ptr); return GetValue<bool>(value_ptr);
} }
int64_t ArgMinFusion::get_top_k() const { int64_t ArgMinFusion::get_top_k() const {
auto value_ptr = GetAttr(kTopK); auto value_ptr = GetAttr(kTopK);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr); return GetValue<int64_t>(value_ptr);
} }

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -40,17 +40,22 @@ void AvgPoolFusion::set_activation_type(ActivationType activation_type) {
bool AvgPoolFusion::get_global() const { bool AvgPoolFusion::get_global() const {
auto value_ptr = GetAttr(kGlobal); auto value_ptr = GetAttr(kGlobal);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<bool>(value_ptr); return GetValue<bool>(value_ptr);
} }
ActivationType AvgPoolFusion::get_activation_type() const { ActivationType AvgPoolFusion::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType); auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr)); return ActivationType(GetValue<int64_t>(value_ptr));
} }
namespace { namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
for (auto item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto op_name = primitive->name(); auto op_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
@ -66,6 +71,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
auto in_w = in_shape[3]; auto in_w = in_shape[3];
auto strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides)); auto strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides));
(void)CheckAndConvertUtils::CheckPositiveVector(kStride, strides, op_name);
auto kernel_h = kernel_size[2]; auto kernel_h = kernel_size[2];
auto kernel_w = kernel_size[3]; auto kernel_w = kernel_size[3];
auto stride_h = strides[2]; auto stride_h = strides[2];
@ -90,8 +96,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
} }
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr arg) { return arg == nullptr; })) { for (auto item : input_args) {
MS_LOG(EXCEPTION) << "nullptr"; MS_EXCEPTION_IF_NULL(item);
} }
return input_args[0]->BuildType(); return input_args[0]->BuildType();
} }

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -52,11 +52,13 @@ void Conv2DBackpropFilterFusion::set_in_channel(const int64_t in_channel) {
ActivationType Conv2DBackpropFilterFusion::get_activation_type() const { ActivationType Conv2DBackpropFilterFusion::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType); auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr)); return ActivationType(GetValue<int64_t>(value_ptr));
} }
int64_t Conv2DBackpropFilterFusion::get_in_channel() const { int64_t Conv2DBackpropFilterFusion::get_in_channel() const {
auto value_ptr = GetAttr(kInChannel); auto value_ptr = GetAttr(kInChannel);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr); return GetValue<int64_t>(value_ptr);
} }
REGISTER_PRIMITIVE_C(kNameConv2DBackpropFilterFusion, Conv2DBackpropFilterFusion); REGISTER_PRIMITIVE_C(kNameConv2DBackpropFilterFusion, Conv2DBackpropFilterFusion);

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -51,11 +51,13 @@ void Conv2DBackpropInputFusion::set_activation_type(const ActivationType &activa
} }
int64_t Conv2DBackpropInputFusion::get_in_channel() const { int64_t Conv2DBackpropInputFusion::get_in_channel() const {
auto value_ptr = GetAttr(kInChannel); auto value_ptr = GetAttr(kInChannel);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr); return GetValue<int64_t>(value_ptr);
} }
ActivationType Conv2DBackpropInputFusion::get_activation_type() const { ActivationType Conv2DBackpropInputFusion::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType); auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr)); return ActivationType(GetValue<int64_t>(value_ptr));
} }
REGISTER_PRIMITIVE_C(kNameConv2DBackpropInputFusion, Conv2DBackpropInputFusion); REGISTER_PRIMITIVE_C(kNameConv2DBackpropInputFusion, Conv2DBackpropInputFusion);

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -48,14 +48,17 @@ void Conv2DFusion::set_activation_type(const ActivationType &activation_type) {
} }
int64_t Conv2DFusion::get_in_channel() const { int64_t Conv2DFusion::get_in_channel() const {
auto value_ptr = GetAttr(kInChannel); auto value_ptr = GetAttr(kInChannel);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr); return GetValue<int64_t>(value_ptr);
} }
std::vector<int64_t> Conv2DFusion::get_pad_list() const { std::vector<int64_t> Conv2DFusion::get_pad_list() const {
auto value_ptr = GetAttr(kPadList); auto value_ptr = GetAttr(kPadList);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr); return GetValue<std::vector<int64_t>>(value_ptr);
} }
ActivationType Conv2DFusion::get_activation_type() const { ActivationType Conv2DFusion::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType); auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr)); return ActivationType(GetValue<int64_t>(value_ptr));
} }
REGISTER_PRIMITIVE_C(kNameConv2DFusion, Conv2DFusion); REGISTER_PRIMITIVE_C(kNameConv2DFusion, Conv2DFusion);

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -71,11 +71,13 @@ void Conv2dTransposeFusion::set_activation_type(ActivationType activation_type)
std::vector<int64_t> Conv2dTransposeFusion::get_output_paddings() const { std::vector<int64_t> Conv2dTransposeFusion::get_output_paddings() const {
auto value_ptr = GetAttr(kOutputPaddings); auto value_ptr = GetAttr(kOutputPaddings);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr); return GetValue<std::vector<int64_t>>(value_ptr);
} }
ActivationType Conv2dTransposeFusion::get_activation_type() const { ActivationType Conv2dTransposeFusion::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType); auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr)); return ActivationType(GetValue<int64_t>(value_ptr));
} }

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -29,6 +29,7 @@ void DivFusion::set_activation_type(const ActivationType &activation_type) {
ActivationType DivFusion::get_activation_type() const { ActivationType DivFusion::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType); auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr)); return ActivationType(GetValue<int64_t>(value_ptr));
} }
REGISTER_PRIMITIVE_C(kNameDivFusion, DivFusion); REGISTER_PRIMITIVE_C(kNameDivFusion, DivFusion);

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -23,6 +23,7 @@ namespace ops {
void EmbeddingLookupFusion::set_max_norm(const float max_norm) { (void)this->AddAttr(kMaxNorm, MakeValue(max_norm)); } void EmbeddingLookupFusion::set_max_norm(const float max_norm) { (void)this->AddAttr(kMaxNorm, MakeValue(max_norm)); }
float EmbeddingLookupFusion::get_max_norm() const { float EmbeddingLookupFusion::get_max_norm() const {
auto value_ptr = GetAttr(kMaxNorm); auto value_ptr = GetAttr(kMaxNorm);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<float>(value_ptr); return GetValue<float>(value_ptr);
} }
void EmbeddingLookupFusion::Init(const float max_norm) { this->set_max_norm(max_norm); } void EmbeddingLookupFusion::Init(const float max_norm) { this->set_max_norm(max_norm); }

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -37,14 +37,17 @@ void ExpFusion::set_shift(const float shift) { (void)this->AddAttr(kShift, MakeV
float ExpFusion::get_base() const { float ExpFusion::get_base() const {
auto value_ptr = GetAttr(kBase); auto value_ptr = GetAttr(kBase);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<float>(value_ptr); return GetValue<float>(value_ptr);
} }
float ExpFusion::get_scale() const { float ExpFusion::get_scale() const {
auto value_ptr = GetAttr(kScale); auto value_ptr = GetAttr(kScale);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<float>(value_ptr); return GetValue<float>(value_ptr);
} }
float ExpFusion::get_shift() const { float ExpFusion::get_shift() const {
auto value_ptr = GetAttr(kShift); auto value_ptr = GetAttr(kShift);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<float>(value_ptr); return GetValue<float>(value_ptr);
} }

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -22,13 +22,25 @@ namespace mindspore {
namespace ops { namespace ops {
void FullConnection::set_has_bias(const bool has_bias) { (void)this->AddAttr(kHasBias, MakeValue(has_bias)); } void FullConnection::set_has_bias(const bool has_bias) { (void)this->AddAttr(kHasBias, MakeValue(has_bias)); }
bool FullConnection::get_has_bias() const { return GetValue<bool>(GetAttr(kHasBias)); } bool FullConnection::get_has_bias() const {
auto value_ptr = GetAttr(kHasBias);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<bool>(value_ptr);
}
void FullConnection::set_axis(const int64_t axis) { (void)this->AddAttr(kAxis, MakeValue(axis)); } void FullConnection::set_axis(const int64_t axis) { (void)this->AddAttr(kAxis, MakeValue(axis)); }
int64_t FullConnection::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); } int64_t FullConnection::get_axis() const {
auto value_ptr = GetAttr(kAxis);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}
void FullConnection::set_use_axis(const bool use_axis) { (void)this->AddAttr(kUseAxis, MakeValue(use_axis)); } void FullConnection::set_use_axis(const bool use_axis) { (void)this->AddAttr(kUseAxis, MakeValue(use_axis)); }
bool FullConnection::get_use_axis() const { return GetValue<bool>(GetAttr(kUseAxis)); } bool FullConnection::get_use_axis() const {
auto value_ptr = GetAttr(kUseAxis);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<bool>(value_ptr);
}
void FullConnection::set_activation_type(const ActivationType &activation_type) { void FullConnection::set_activation_type(const ActivationType &activation_type) {
int64_t swi = activation_type; int64_t swi = activation_type;
@ -36,6 +48,7 @@ void FullConnection::set_activation_type(const ActivationType &activation_type)
} }
ActivationType FullConnection::get_activation_type() const { ActivationType FullConnection::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType); auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr)); return ActivationType(GetValue<int64_t>(value_ptr));
} }
void FullConnection::Init(const bool has_bias, const int64_t axis, const bool use_axis, void FullConnection::Init(const bool has_bias, const int64_t axis, const bool use_axis,
@ -57,14 +70,16 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input1->BuildShape())[kShape]; auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input1->BuildShape())[kShape];
auto prim_axis = GetValue<int64_t>(primitive->GetAttr(kAxis)); auto prim_axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
auto has_bias = GetValue<bool>(primitive->GetAttr(kHasBias)); auto has_bias = GetValue<bool>(primitive->GetAttr(kHasBias));
const int64_t input_num_bias = 3;
const int64_t input_num = 2;
if (has_bias) { if (has_bias) {
(void)CheckAndConvertUtils::CheckInteger("input_args.size()", input_args.size(), kEqual, 3, prim_name); (void)CheckAndConvertUtils::CheckInteger("input_args.size()", input_args.size(), kEqual, input_num_bias, prim_name);
} else { } else {
(void)CheckAndConvertUtils::CheckInteger("input_args.size()", input_args.size(), kEqual, 2, prim_name); (void)CheckAndConvertUtils::CheckInteger("input_args.size()", input_args.size(), kEqual, input_num, prim_name);
} }
auto use_axis = GetValue<bool>(primitive->GetAttr(kUseAxis)); auto use_axis = GetValue<bool>(primitive->GetAttr(kUseAxis));
if (use_axis && (prim_axis < 1 || prim_axis > (int64_t)input0_shape.size())) { if (use_axis && (prim_axis < 1 || prim_axis > (int64_t)input0_shape.size())) {
MS_EXCEPTION(ValueError) << "Full Connection axis invalid"; MS_EXCEPTION(ValueError) << "Full Connection axis is invalid";
} }
int64_t new_k = 1; int64_t new_k = 1;
if (use_axis) { if (use_axis) {
@ -72,7 +87,7 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P
new_k *= input0_shape[t]; new_k *= input0_shape[t];
} }
if (new_k != input1_shape[1]) { if (new_k != input1_shape[1]) {
MS_EXCEPTION(ValueError) << "Input1 size invalid"; MS_EXCEPTION(ValueError) << "Input1 size is invalid";
} }
} else { } else {
new_k = input1_shape[1]; new_k = input1_shape[1];
@ -80,7 +95,7 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P
if (has_bias) { if (has_bias) {
auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
if (input2_shape[0] != input1_shape[0]) { if (input2_shape[0] != input1_shape[0]) {
MS_EXCEPTION(ValueError) << "Bias size invalid"; MS_EXCEPTION(ValueError) << "Bias size is invalid";
} }
} }
std::vector<int64_t> out_shape = {(int64_t)input0_shape.size()}; std::vector<int64_t> out_shape = {(int64_t)input0_shape.size()};

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -34,6 +34,7 @@ void L2NormalizeFusion::set_activation_type(const ActivationType &activation_typ
ActivationType L2NormalizeFusion::get_activation_type() const { ActivationType L2NormalizeFusion::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType); auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr)); return ActivationType(GetValue<int64_t>(value_ptr));
} }
REGISTER_PRIMITIVE_C(kNameL2NormalizeFusion, L2NormalizeFusion); REGISTER_PRIMITIVE_C(kNameL2NormalizeFusion, L2NormalizeFusion);

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -32,6 +32,7 @@ void LayerNormFusion::set_elementwise_affine(const bool elementwise_affine) {
bool LayerNormFusion::get_elementwise_affine() const { bool LayerNormFusion::get_elementwise_affine() const {
auto value_ptr = GetAttr(kElementwiseAffine); auto value_ptr = GetAttr(kElementwiseAffine);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<bool>(value_ptr); return GetValue<bool>(value_ptr);
} }
REGISTER_PRIMITIVE_C(kNameLayerNormFusion, LayerNormFusion); REGISTER_PRIMITIVE_C(kNameLayerNormFusion, LayerNormFusion);

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -40,24 +40,28 @@ void MaxPoolFusion::set_activation_type(ActivationType activation_type) {
bool MaxPoolFusion::get_global() const { bool MaxPoolFusion::get_global() const {
auto value_ptr = GetAttr(kGlobal); auto value_ptr = GetAttr(kGlobal);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<bool>(value_ptr); return GetValue<bool>(value_ptr);
} }
ActivationType MaxPoolFusion::get_activation_type() const { ActivationType MaxPoolFusion::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType); auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr)); return ActivationType(GetValue<int64_t>(value_ptr));
} }
namespace { namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
MS_EXCEPTION_IF_NULL(input_args[0]);
auto op_name = primitive->name(); auto op_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) { if (format == NHWC) {
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
} }
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, 4, op_name); const int64_t in_shape_size = 4;
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, in_shape_size, op_name);
auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)); auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode))); auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode)));
auto batch = in_shape[0]; auto batch = in_shape[0];
@ -84,21 +88,22 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
out_shape = {batch, out_h, out_w, channel}; out_shape = {batch, out_h, out_w, channel};
} }
if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) { if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {
MS_LOG(EXCEPTION) << "Kernel size is not valid."; MS_LOG(EXCEPTION) << "Kernel size is invalid.";
} }
return std::make_shared<abstract::Shape>(out_shape); return std::make_shared<abstract::Shape>(out_shape);
} }
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) {
MS_LOG(EXCEPTION) << "nullptr";
}
return input_args[0]->BuildType(); return input_args[0]->BuildType();
} }
} // namespace } // namespace
AbstractBasePtr MaxPoolFusionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr MaxPoolFusionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) { const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
for (auto item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape()); InferShape(primitive, input_args)->shape());
} }

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -29,6 +29,7 @@ void MulFusion::set_activation_type(const ActivationType &activation_type) {
} }
ActivationType MulFusion::get_activation_type() const { ActivationType MulFusion::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType); auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr)); return ActivationType(GetValue<int64_t>(value_ptr));
} }
void MulFusion::Init(const ActivationType &activation_type) { this->set_activation_type(activation_type); } void MulFusion::Init(const ActivationType &activation_type) { this->set_activation_type(activation_type); }

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -39,10 +39,12 @@ void PadFusion::set_constant_value(const float constant_value) {
PaddingMode PadFusion::get_padding_mode() const { PaddingMode PadFusion::get_padding_mode() const {
auto value_ptr = GetAttr(kPaddingMode); auto value_ptr = GetAttr(kPaddingMode);
MS_EXCEPTION_IF_NULL(value_ptr);
return PaddingMode(GetValue<int64_t>(value_ptr)); return PaddingMode(GetValue<int64_t>(value_ptr));
} }
float PadFusion::get_constant_value() const { float PadFusion::get_constant_value() const {
auto value_ptr = GetAttr(kConstantValue); auto value_ptr = GetAttr(kConstantValue);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<float>(value_ptr); return GetValue<float>(value_ptr);
} }

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -25,6 +25,7 @@ void PartialFusion::set_sub_graph_index(const int64_t sub_graph_index) {
} }
int64_t PartialFusion::get_sub_graph_index() const { int64_t PartialFusion::get_sub_graph_index() const {
auto value_ptr = GetAttr(kSubGraphIndex); auto value_ptr = GetAttr(kSubGraphIndex);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr); return GetValue<int64_t>(value_ptr);
} }
REGISTER_PRIMITIVE_C(kNamePartialFusion, PartialFusion); REGISTER_PRIMITIVE_C(kNamePartialFusion, PartialFusion);

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -42,6 +42,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
} }
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
for (const auto &item : input_args) { for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
} }
@ -54,6 +55,10 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
AbstractBasePtr PowFusionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr PowFusionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) { const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 2;
CheckAndConvertUtils::CheckInteger("PowFusion infer", SizeToLong(input_args.size()), kGreaterEqual, input_num,
primitive->name());
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape()); InferShape(primitive, input_args)->shape());
} }

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -36,10 +36,12 @@ void PReLUFusion::set_slope(const std::vector<float> &slope) { (void)this->AddAt
bool PReLUFusion::get_channel_shared() const { bool PReLUFusion::get_channel_shared() const {
auto value_ptr = GetAttr(kChannelShared); auto value_ptr = GetAttr(kChannelShared);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<bool>(value_ptr); return GetValue<bool>(value_ptr);
} }
std::vector<float> PReLUFusion::get_slope() const { std::vector<float> PReLUFusion::get_slope() const {
auto value_ptr = GetAttr(kSlope); auto value_ptr = GetAttr(kSlope);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<float>>(value_ptr); return GetValue<std::vector<float>>(value_ptr);
} }
REGISTER_PRIMITIVE_C(kNamePReLUFusion, PReLUFusion); REGISTER_PRIMITIVE_C(kNamePReLUFusion, PReLUFusion);

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -41,21 +41,25 @@ void ReduceFusion::set_coeff(const float coeff) { (void)this->AddAttr(kCoeff, Ma
bool ReduceFusion::get_keep_dims() const { bool ReduceFusion::get_keep_dims() const {
auto value_ptr = GetAttr(kKeepDims); auto value_ptr = GetAttr(kKeepDims);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<bool>(value_ptr); return GetValue<bool>(value_ptr);
} }
ReduceMode ReduceFusion::get_mode() const { ReduceMode ReduceFusion::get_mode() const {
auto value_ptr = GetAttr(kMode); auto value_ptr = GetAttr(kMode);
MS_EXCEPTION_IF_NULL(value_ptr);
return ReduceMode(GetValue<int64_t>(value_ptr)); return ReduceMode(GetValue<int64_t>(value_ptr));
} }
bool ReduceFusion::get_reduce_to_end() const { bool ReduceFusion::get_reduce_to_end() const {
auto value_ptr = GetAttr(kReduceToEnd); auto value_ptr = GetAttr(kReduceToEnd);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<bool>(value_ptr); return GetValue<bool>(value_ptr);
} }
float ReduceFusion::get_coeff() const { float ReduceFusion::get_coeff() const {
auto value_ptr = GetAttr(kCoeff); auto value_ptr = GetAttr(kCoeff);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<float>(value_ptr); return GetValue<float>(value_ptr);
} }

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -32,6 +32,7 @@ void ScaleFusion::set_activation_type(const ActivationType &activation_type) {
ActivationType ScaleFusion::get_activation_type() const { ActivationType ScaleFusion::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType); auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr)); return ActivationType(GetValue<int64_t>(value_ptr));
} }
REGISTER_PRIMITIVE_C(kNameScaleFusion, ScaleFusion); REGISTER_PRIMITIVE_C(kNameScaleFusion, ScaleFusion);

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -26,6 +26,7 @@ void SliceFusion::set_axes(const std::vector<int64_t> &axes) { (void)this->AddAt
std::vector<int64_t> SliceFusion::get_axes() const { std::vector<int64_t> SliceFusion::get_axes() const {
auto value_ptr = GetAttr(kAxes); auto value_ptr = GetAttr(kAxes);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr); return GetValue<std::vector<int64_t>>(value_ptr);
} }
@ -39,6 +40,8 @@ AbstractBasePtr SliceFusionInfer(const abstract::AnalysisEnginePtr &, const Prim
auto size_v = input_args[2]->BuildValue(); auto size_v = input_args[2]->BuildValue();
auto x_type = input_args[0]->BuildType(); auto x_type = input_args[0]->BuildType();
MS_EXCEPTION_IF_NULL(x_type); MS_EXCEPTION_IF_NULL(x_type);
MS_EXCEPTION_IF_NULL(begin_v);
MS_EXCEPTION_IF_NULL(size_v);
auto tensor_type = x_type->cast<TensorTypePtr>(); auto tensor_type = x_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type); MS_EXCEPTION_IF_NULL(tensor_type);
auto data_type = tensor_type->element(); auto data_type = tensor_type->element();

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -29,6 +29,7 @@ void SubFusion::set_activation_type(const ActivationType &activation_type) {
ActivationType SubFusion::get_activation_type() const { ActivationType SubFusion::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType); auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr)); return ActivationType(GetValue<int64_t>(value_ptr));
} }
REGISTER_PRIMITIVE_C(kNameSubFusion, SubFusion); REGISTER_PRIMITIVE_C(kNameSubFusion, SubFusion);

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -26,6 +26,7 @@ void TileFusion::set_dims(const std::vector<int64_t> &dims) { (void)this->AddAtt
std::vector<int64_t> TileFusion::get_dims() const { std::vector<int64_t> TileFusion::get_dims() const {
auto value_ptr = GetAttr(kDims); auto value_ptr = GetAttr(kDims);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr); return GetValue<std::vector<int64_t>>(value_ptr);
} }
REGISTER_PRIMITIVE_C(kNameTileFusion, TileFusion); REGISTER_PRIMITIVE_C(kNameTileFusion, TileFusion);

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -32,11 +32,13 @@ void TopKFusion::set_largest(const int64_t largest) { (void)this->AddAttr(kLarge
int64_t TopKFusion::get_axis() const { int64_t TopKFusion::get_axis() const {
auto value_ptr = GetAttr(kAxis); auto value_ptr = GetAttr(kAxis);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr); return GetValue<int64_t>(value_ptr);
} }
int64_t TopKFusion::get_largest() const { int64_t TopKFusion::get_largest() const {
auto value_ptr = GetAttr(kLargest); auto value_ptr = GetAttr(kLargest);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr); return GetValue<int64_t>(value_ptr);
} }
REGISTER_PRIMITIVE_C(kNameTopKFusion, TopKFusion); REGISTER_PRIMITIVE_C(kNameTopKFusion, TopKFusion);

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -38,6 +38,7 @@ void ActivationGrad::set_activation_type(const ActivationType &type) {
ActivationType ActivationGrad::get_activation_type() const { ActivationType ActivationGrad::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType); auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr)); return ActivationType(GetValue<int64_t>(value_ptr));
} }
@ -45,6 +46,7 @@ void ActivationGrad::set_alpha(const float alpha) { (void)this->AddAttr(kAlpha,
float ActivationGrad::get_alpha() const { float ActivationGrad::get_alpha() const {
auto value_ptr = GetAttr(kAlpha); auto value_ptr = GetAttr(kAlpha);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<float>(value_ptr); return GetValue<float>(value_ptr);
} }
REGISTER_PRIMITIVE_C(kNameActivationGrad, ActivationGrad); REGISTER_PRIMITIVE_C(kNameActivationGrad, ActivationGrad);

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -31,6 +31,7 @@ void BatchNormGrad::set_epsilon(const float epsilon) { (void)this->AddAttr(kEpsi
float BatchNormGrad::get_epsilon() const { float BatchNormGrad::get_epsilon() const {
auto value_ptr = this->GetAttr(kEpsilon); auto value_ptr = this->GetAttr(kEpsilon);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<float>(value_ptr); return GetValue<float>(value_ptr);
} }
@ -38,15 +39,19 @@ void BatchNormGrad::set_is_training(const bool is_training) { this->AddAttr(kIsT
bool BatchNormGrad::get_is_training() const { bool BatchNormGrad::get_is_training() const {
auto value_ptr = this->GetAttr(kIsTraining); auto value_ptr = this->GetAttr(kIsTraining);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<bool>(value_ptr); return GetValue<bool>(value_ptr);
} }
AbstractBasePtr BatchNormGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr BatchNormGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) { const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
MS_EXCEPTION_IF_NULL(input_args[1]); for (auto item : input_args) {
MS_EXCEPTION_IF_NULL(input_args[2]); MS_EXCEPTION_IF_NULL(item);
MS_EXCEPTION_IF_NULL(input_args[3]); }
const int64_t input_num = 5;
CheckAndConvertUtils::CheckInteger("BatchNormGrad infer", SizeToLong(input_args.size()), kGreaterEqual, input_num,
primitive->name());
auto y_backprop_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; auto y_backprop_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
CheckAndConvertUtils::Check("BatchNorm y_backprop_shape", y_backprop_shape, kEqual, "BatchNorm x_shape", x_shape); CheckAndConvertUtils::Check("BatchNorm y_backprop_shape", y_backprop_shape, kEqual, "BatchNorm x_shape", x_shape);

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -38,9 +38,6 @@ abstract::ShapePtr BinaryCrossEntroyGradInferShape(const PrimitivePtr &primitive
} }
TypePtr BinaryCrossEntroyGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { TypePtr BinaryCrossEntroyGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
std::map<std::string, TypePtr> types; std::map<std::string, TypePtr> types;
types.emplace("x_shape", input_args[0]->BuildType()); types.emplace("x_shape", input_args[0]->BuildType());
@ -67,6 +64,13 @@ Reduction BinaryCrossEntropyGrad::get_reduction() const {
AbstractBasePtr BinaryCrossEntropyGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr BinaryCrossEntropyGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) { const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const int64_t input_num = 4;
CheckAndConvertUtils::CheckInteger("BinaryCrossEntropyGrad infer", SizeToLong(input_args.size()), kGreaterEqual,
input_num, primitive->name());
return std::make_shared<abstract::AbstractTensor>(BinaryCrossEntroyGradInferType(primitive, input_args), return std::make_shared<abstract::AbstractTensor>(BinaryCrossEntroyGradInferType(primitive, input_args),
BinaryCrossEntroyGradInferShape(primitive, input_args)->shape()); BinaryCrossEntroyGradInferShape(primitive, input_args)->shape());
} }

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -30,6 +30,7 @@ void BNGrad::set_eps(const float eps) { (void)this->AddAttr(kEps, MakeValue(eps)
float BNGrad::get_eps() const { float BNGrad::get_eps() const {
auto value_ptr = this->GetAttr(kEps); auto value_ptr = this->GetAttr(kEps);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<float>(value_ptr); return GetValue<float>(value_ptr);
} }
@ -37,6 +38,7 @@ void BNGrad::set_momentum(const float momentum) { (void)this->AddAttr(kMomentum,
float BNGrad::get_momentum() const { float BNGrad::get_momentum() const {
auto value_ptr = this->GetAttr(kMomentum); auto value_ptr = this->GetAttr(kMomentum);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<float>(value_ptr); return GetValue<float>(value_ptr);
} }
REGISTER_PRIMITIVE_C(kNameBNGrad, BNGrad); REGISTER_PRIMITIVE_C(kNameBNGrad, BNGrad);

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -71,6 +71,7 @@ void Conv2DBackpropFilter::set_out_channel(const int64_t out_channel) {
int64_t Conv2DBackpropFilter::get_out_channel() const { int64_t Conv2DBackpropFilter::get_out_channel() const {
auto value_ptr = GetAttr(kOutChannel); auto value_ptr = GetAttr(kOutChannel);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr); return GetValue<int64_t>(value_ptr);
} }
@ -80,6 +81,7 @@ void Conv2DBackpropFilter::set_kernel_size(const std::vector<int64_t> &kernel_si
std::vector<int64_t> Conv2DBackpropFilter::get_kernel_size() const { std::vector<int64_t> Conv2DBackpropFilter::get_kernel_size() const {
auto value_ptr = GetAttr(kKernelSize); auto value_ptr = GetAttr(kKernelSize);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr); return GetValue<std::vector<int64_t>>(value_ptr);
} }
@ -90,6 +92,7 @@ void Conv2DBackpropFilter::set_pad_mode(const PadMode &pad_mode) {
PadMode Conv2DBackpropFilter::get_pad_mode() const { PadMode Conv2DBackpropFilter::get_pad_mode() const {
auto value_ptr = GetAttr(kPadMode); auto value_ptr = GetAttr(kPadMode);
MS_EXCEPTION_IF_NULL(value_ptr);
return PadMode(GetValue<int64_t>(value_ptr)); return PadMode(GetValue<int64_t>(value_ptr));
} }
@ -99,6 +102,7 @@ void Conv2DBackpropFilter::set_pad_list(const std::vector<int64_t> &pad_list) {
std::vector<int64_t> Conv2DBackpropFilter::get_pad_list() const { std::vector<int64_t> Conv2DBackpropFilter::get_pad_list() const {
auto value_ptr = GetAttr(kPadList); auto value_ptr = GetAttr(kPadList);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr); return GetValue<std::vector<int64_t>>(value_ptr);
} }
@ -106,6 +110,7 @@ void Conv2DBackpropFilter::set_mode(const int64_t mode) { (void)this->AddAttr(kM
int64_t Conv2DBackpropFilter::get_mode() const { int64_t Conv2DBackpropFilter::get_mode() const {
auto value_ptr = GetAttr(kMode); auto value_ptr = GetAttr(kMode);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr); return GetValue<int64_t>(value_ptr);
} }
@ -113,6 +118,7 @@ void Conv2DBackpropFilter::set_stride(const std::vector<int64_t> &stride) { this
std::vector<int64_t> Conv2DBackpropFilter::get_stride() const { std::vector<int64_t> Conv2DBackpropFilter::get_stride() const {
auto value_ptr = GetAttr(kStride); auto value_ptr = GetAttr(kStride);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr); return GetValue<std::vector<int64_t>>(value_ptr);
} }
@ -122,6 +128,7 @@ void Conv2DBackpropFilter::set_dilation(const std::vector<int64_t> &dilation) {
std::vector<int64_t> Conv2DBackpropFilter::get_dilation() const { std::vector<int64_t> Conv2DBackpropFilter::get_dilation() const {
auto value_ptr = GetAttr(kDilation); auto value_ptr = GetAttr(kDilation);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr); return GetValue<std::vector<int64_t>>(value_ptr);
} }
@ -129,6 +136,7 @@ void Conv2DBackpropFilter::set_group(const int64_t group) { (void)this->AddAttr(
int64_t Conv2DBackpropFilter::get_group() const { int64_t Conv2DBackpropFilter::get_group() const {
auto value_ptr = GetAttr(kGroup); auto value_ptr = GetAttr(kGroup);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr); return GetValue<int64_t>(value_ptr);
} }
@ -139,6 +147,7 @@ void Conv2DBackpropFilter::set_format(const Format &format) {
Format Conv2DBackpropFilter::get_format() const { Format Conv2DBackpropFilter::get_format() const {
auto value_ptr = GetAttr(kFormat); auto value_ptr = GetAttr(kFormat);
MS_EXCEPTION_IF_NULL(value_ptr);
return Format(GetValue<int64_t>(value_ptr)); return Format(GetValue<int64_t>(value_ptr));
} }

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -172,51 +172,61 @@ void Conv2DBackpropInput::set_pad_list(const std::vector<int64_t> &pad_list) {
int64_t Conv2DBackpropInput::get_out_channel() const { int64_t Conv2DBackpropInput::get_out_channel() const {
auto value_ptr = GetAttr(kOutChannel); auto value_ptr = GetAttr(kOutChannel);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr); return GetValue<int64_t>(value_ptr);
} }
std::vector<int64_t> Conv2DBackpropInput::get_kernel_size() const { std::vector<int64_t> Conv2DBackpropInput::get_kernel_size() const {
auto value_ptr = GetAttr(kKernelSize); auto value_ptr = GetAttr(kKernelSize);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr); return GetValue<std::vector<int64_t>>(value_ptr);
} }
std::vector<int64_t> Conv2DBackpropInput::get_stride() const { std::vector<int64_t> Conv2DBackpropInput::get_stride() const {
auto value_ptr = GetAttr(kStride); auto value_ptr = GetAttr(kStride);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr); return GetValue<std::vector<int64_t>>(value_ptr);
} }
std::vector<int64_t> Conv2DBackpropInput::get_dilation() const { std::vector<int64_t> Conv2DBackpropInput::get_dilation() const {
auto value_ptr = GetAttr(kDilation); auto value_ptr = GetAttr(kDilation);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr); return GetValue<std::vector<int64_t>>(value_ptr);
} }
PadMode Conv2DBackpropInput::get_pad_mode() const { PadMode Conv2DBackpropInput::get_pad_mode() const {
auto value_ptr = GetAttr(kPadMode); auto value_ptr = GetAttr(kPadMode);
MS_EXCEPTION_IF_NULL(value_ptr);
return PadMode(GetValue<int64_t>(value_ptr)); return PadMode(GetValue<int64_t>(value_ptr));
} }
std::vector<int64_t> Conv2DBackpropInput::get_pad() const { std::vector<int64_t> Conv2DBackpropInput::get_pad() const {
auto value_ptr = GetAttr(kPad); auto value_ptr = GetAttr(kPad);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr); return GetValue<std::vector<int64_t>>(value_ptr);
} }
int64_t Conv2DBackpropInput::get_mode() const { int64_t Conv2DBackpropInput::get_mode() const {
auto value_ptr = GetAttr(kMode); auto value_ptr = GetAttr(kMode);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr); return GetValue<int64_t>(value_ptr);
} }
int64_t Conv2DBackpropInput::get_group() const { int64_t Conv2DBackpropInput::get_group() const {
auto value_ptr = GetAttr(kGroup); auto value_ptr = GetAttr(kGroup);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr); return GetValue<int64_t>(value_ptr);
} }
Format Conv2DBackpropInput::get_format() const { Format Conv2DBackpropInput::get_format() const {
auto value_ptr = GetAttr(kFormat); auto value_ptr = GetAttr(kFormat);
MS_EXCEPTION_IF_NULL(value_ptr);
return Format(GetValue<int64_t>(value_ptr)); return Format(GetValue<int64_t>(value_ptr));
} }
std::vector<int64_t> Conv2DBackpropInput::get_pad_list() const { std::vector<int64_t> Conv2DBackpropInput::get_pad_list() const {
auto value_ptr = GetAttr(kPadList); auto value_ptr = GetAttr(kPadList);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr); return GetValue<std::vector<int64_t>>(value_ptr);
} }
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DBackpropInput, prim::kPrimConv2DBackpropInput, Conv2DBackpropInputInfer, nullptr, REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DBackpropInput, prim::kPrimConv2DBackpropInput, Conv2DBackpropInputInfer, nullptr,

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -45,6 +45,7 @@ void DeConv2DGradFilter::set_in_channel(const int64_t in_channel) {
int64_t DeConv2DGradFilter::get_in_channel() const { int64_t DeConv2DGradFilter::get_in_channel() const {
auto value_ptr = GetAttr(kInChannel); auto value_ptr = GetAttr(kInChannel);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr); return GetValue<int64_t>(value_ptr);
} }
@ -54,6 +55,7 @@ void DeConv2DGradFilter::set_out_channel(const int64_t out_channel) {
int64_t DeConv2DGradFilter::get_out_channel() const { int64_t DeConv2DGradFilter::get_out_channel() const {
auto value_ptr = GetAttr(kOutChannel); auto value_ptr = GetAttr(kOutChannel);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr); return GetValue<int64_t>(value_ptr);
} }
@ -63,6 +65,7 @@ void DeConv2DGradFilter::set_kernel_size(const std::vector<int64_t> &kernel_size
std::vector<int64_t> DeConv2DGradFilter::get_kernel_size() const { std::vector<int64_t> DeConv2DGradFilter::get_kernel_size() const {
auto value_ptr = GetAttr(kKernelSize); auto value_ptr = GetAttr(kKernelSize);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr); return GetValue<std::vector<int64_t>>(value_ptr);
} }
@ -73,6 +76,7 @@ void DeConv2DGradFilter::set_pad_mode(const PadMode &pad_mode) {
PadMode DeConv2DGradFilter::get_pad_mode() const { PadMode DeConv2DGradFilter::get_pad_mode() const {
auto value_ptr = GetAttr(kPadMode); auto value_ptr = GetAttr(kPadMode);
MS_EXCEPTION_IF_NULL(value_ptr);
return PadMode(GetValue<int64_t>(value_ptr)); return PadMode(GetValue<int64_t>(value_ptr));
} }
@ -82,6 +86,7 @@ void DeConv2DGradFilter::set_pad_list(const std::vector<int64_t> &pad_list) {
std::vector<int64_t> DeConv2DGradFilter::get_pad_list() const { std::vector<int64_t> DeConv2DGradFilter::get_pad_list() const {
auto value_ptr = GetAttr(kPadList); auto value_ptr = GetAttr(kPadList);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr); return GetValue<std::vector<int64_t>>(value_ptr);
} }
@ -91,6 +96,7 @@ void DeConv2DGradFilter::set_stride(const std::vector<int64_t> &stride) {
std::vector<int64_t> DeConv2DGradFilter::get_stride() const { std::vector<int64_t> DeConv2DGradFilter::get_stride() const {
auto value_ptr = GetAttr(kStride); auto value_ptr = GetAttr(kStride);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr); return GetValue<std::vector<int64_t>>(value_ptr);
} }
@ -100,6 +106,7 @@ void DeConv2DGradFilter::set_dilation(const std::vector<int64_t> &dilation) {
std::vector<int64_t> DeConv2DGradFilter::get_dilation() const { std::vector<int64_t> DeConv2DGradFilter::get_dilation() const {
auto value_ptr = GetAttr(kDilation); auto value_ptr = GetAttr(kDilation);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr); return GetValue<std::vector<int64_t>>(value_ptr);
} }
@ -107,6 +114,7 @@ void DeConv2DGradFilter::set_group(const int64_t group) { (void)this->AddAttr(kG
int64_t DeConv2DGradFilter::get_group() const { int64_t DeConv2DGradFilter::get_group() const {
auto value_ptr = GetAttr(kGroup); auto value_ptr = GetAttr(kGroup);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr); return GetValue<int64_t>(value_ptr);
} }
@ -117,6 +125,7 @@ void DeConv2DGradFilter::set_format(const Format &format) {
Format DeConv2DGradFilter::get_format() const { Format DeConv2DGradFilter::get_format() const {
auto value_ptr = GetAttr(kFormat); auto value_ptr = GetAttr(kFormat);
MS_EXCEPTION_IF_NULL(value_ptr);
return Format(GetValue<int64_t>(value_ptr)); return Format(GetValue<int64_t>(value_ptr));
} }
@ -127,6 +136,7 @@ void DeConv2DGradFilter::set_activation_type(const ActivationType &activation_ty
ActivationType DeConv2DGradFilter::get_activation_type() const { ActivationType DeConv2DGradFilter::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType); auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr)); return ActivationType(GetValue<int64_t>(value_ptr));
} }
@ -134,6 +144,7 @@ void DeConv2DGradFilter::set_has_bias(const bool has_bias) { (void)this->AddAttr
bool DeConv2DGradFilter::get_has_bias() const { bool DeConv2DGradFilter::get_has_bias() const {
auto value_ptr = GetAttr(kHasBias); auto value_ptr = GetAttr(kHasBias);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<bool>(value_ptr); return GetValue<bool>(value_ptr);
} }
REGISTER_PRIMITIVE_C(kNameDeConv2DGradFilter, DeConv2DGradFilter); REGISTER_PRIMITIVE_C(kNameDeConv2DGradFilter, DeConv2DGradFilter);

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -28,6 +28,7 @@ void DropoutGrad::set_keep_prob(const float keep_prob) {
float DropoutGrad::get_keep_prob() const { float DropoutGrad::get_keep_prob() const {
auto value_ptr = GetAttr(kKeepProb); auto value_ptr = GetAttr(kKeepProb);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<float>(value_ptr); return GetValue<float>(value_ptr);
} }
@ -55,6 +56,13 @@ TypePtr DropoutGradInferType(const PrimitivePtr &prim, const std::vector<Abstrac
AbstractBasePtr DropoutGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr DropoutGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) { const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
for (auto item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const int64_t input_num = 2;
CheckAndConvertUtils::CheckInteger("DropoutGrad infer", SizeToLong(input_args.size()), kGreaterEqual, input_num,
primitive->name());
return std::make_shared<abstract::AbstractTensor>(DropoutGradInferType(primitive, input_args), return std::make_shared<abstract::AbstractTensor>(DropoutGradInferType(primitive, input_args),
DropoutGradInferShape(primitive, input_args)->shape()); DropoutGradInferShape(primitive, input_args)->shape());
} }

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -22,7 +22,8 @@ AbstractBasePtr FlattenGradInfer(const abstract::AnalysisEnginePtr &, const Prim
const std::vector<AbstractBasePtr> &input_args) { const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, prim_name); const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
for (const auto &item : input_args) { for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
} }

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -46,6 +46,7 @@ void GroupConv2DGradInput::set_in_channel(const int64_t &in_channel) {
int64_t GroupConv2DGradInput::get_in_channel() const { int64_t GroupConv2DGradInput::get_in_channel() const {
auto value_ptr = GetAttr(kInChannel); auto value_ptr = GetAttr(kInChannel);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr); return GetValue<int64_t>(value_ptr);
} }
@ -55,6 +56,7 @@ void GroupConv2DGradInput::set_out_channel(const int64_t &out_channel) {
int64_t GroupConv2DGradInput::get_out_channel() const { int64_t GroupConv2DGradInput::get_out_channel() const {
auto value_ptr = GetAttr(kOutChannel); auto value_ptr = GetAttr(kOutChannel);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr); return GetValue<int64_t>(value_ptr);
} }
@ -64,6 +66,7 @@ void GroupConv2DGradInput::set_kernel_size(const std::vector<int64_t> &kernel_si
std::vector<int64_t> GroupConv2DGradInput::get_kernel_size() const { std::vector<int64_t> GroupConv2DGradInput::get_kernel_size() const {
auto value_ptr = GetAttr(kKernelSize); auto value_ptr = GetAttr(kKernelSize);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr); return GetValue<std::vector<int64_t>>(value_ptr);
} }
@ -74,6 +77,7 @@ void GroupConv2DGradInput::set_pad_mode(const PadMode &pad_mode) {
PadMode GroupConv2DGradInput::get_pad_mode() const { PadMode GroupConv2DGradInput::get_pad_mode() const {
auto value_ptr = GetAttr(kPadMode); auto value_ptr = GetAttr(kPadMode);
MS_EXCEPTION_IF_NULL(value_ptr);
return PadMode(GetValue<int64_t>(value_ptr)); return PadMode(GetValue<int64_t>(value_ptr));
} }
@ -83,6 +87,7 @@ void GroupConv2DGradInput::set_pad_list(const std::vector<int64_t> &pad_list) {
std::vector<int64_t> GroupConv2DGradInput::get_pad_list() const { std::vector<int64_t> GroupConv2DGradInput::get_pad_list() const {
auto value_ptr = GetAttr(kPadList); auto value_ptr = GetAttr(kPadList);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr); return GetValue<std::vector<int64_t>>(value_ptr);
} }
@ -90,6 +95,7 @@ void GroupConv2DGradInput::set_stride(const std::vector<int64_t> &stride) { this
std::vector<int64_t> GroupConv2DGradInput::get_stride() const { std::vector<int64_t> GroupConv2DGradInput::get_stride() const {
auto value_ptr = GetAttr(kStride); auto value_ptr = GetAttr(kStride);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr); return GetValue<std::vector<int64_t>>(value_ptr);
} }
@ -99,6 +105,7 @@ void GroupConv2DGradInput::set_dilation(const std::vector<int64_t> &dilation) {
std::vector<int64_t> GroupConv2DGradInput::get_dilation() const { std::vector<int64_t> GroupConv2DGradInput::get_dilation() const {
auto value_ptr = GetAttr(kDilation); auto value_ptr = GetAttr(kDilation);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr); return GetValue<std::vector<int64_t>>(value_ptr);
} }
@ -106,6 +113,7 @@ void GroupConv2DGradInput::set_group(const int64_t &group) { (void)this->AddAttr
int64_t GroupConv2DGradInput::get_group() const { int64_t GroupConv2DGradInput::get_group() const {
auto value_ptr = GetAttr(kGroup); auto value_ptr = GetAttr(kGroup);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr); return GetValue<int64_t>(value_ptr);
} }
@ -114,7 +122,9 @@ void GroupConv2DGradInput::set_input_shape(const std::vector<int64_t> &input_sha
} }
std::vector<int64_t> GroupConv2DGradInput::get_input_shape() const { std::vector<int64_t> GroupConv2DGradInput::get_input_shape() const {
return GetValue<std::vector<int64_t>>(GetAttr(kInputShape)); auto value_ptr = GetAttr(kInputShape);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr);
} }
void GroupConv2DGradInput::set_format(const Format &format) { void GroupConv2DGradInput::set_format(const Format &format) {
@ -124,6 +134,7 @@ void GroupConv2DGradInput::set_format(const Format &format) {
Format GroupConv2DGradInput::get_format() const { Format GroupConv2DGradInput::get_format() const {
auto value_ptr = GetAttr(kFormat); auto value_ptr = GetAttr(kFormat);
MS_EXCEPTION_IF_NULL(value_ptr);
return Format(GetValue<int64_t>(value_ptr)); return Format(GetValue<int64_t>(value_ptr));
} }
@ -134,6 +145,7 @@ void GroupConv2DGradInput::set_activation_type(const ActivationType &activation_
ActivationType GroupConv2DGradInput::get_activation_type() const { ActivationType GroupConv2DGradInput::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType); auto value_ptr = GetAttr(kActivationType);
MS_EXCEPTION_IF_NULL(value_ptr);
return ActivationType(GetValue<int64_t>(value_ptr)); return ActivationType(GetValue<int64_t>(value_ptr));
} }
@ -141,6 +153,7 @@ void GroupConv2DGradInput::set_has_bias(const bool has_bias) { (void)this->AddAt
bool GroupConv2DGradInput::get_has_bias() const { bool GroupConv2DGradInput::get_has_bias() const {
auto value_ptr = GetAttr(kHasBias); auto value_ptr = GetAttr(kHasBias);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<bool>(value_ptr); return GetValue<bool>(value_ptr);
} }
AbstractBasePtr GroupConv2DGradInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr GroupConv2DGradInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
@ -152,11 +165,16 @@ AbstractBasePtr GroupConv2DGradInputInfer(const abstract::AnalysisEnginePtr &, c
MS_EXCEPTION_IF_NULL(input_args[0]); MS_EXCEPTION_IF_NULL(input_args[0]);
// Infer shape // Infer shape
auto shape = GetValue<std::vector<int64_t>>(primitive->GetAttr(kInputShape)); auto shape_ptr = primitive->GetAttr(kInputShape);
MS_EXCEPTION_IF_NULL(shape_ptr);
auto shape = GetValue<std::vector<int64_t>>(shape_ptr);
// Infer type // Infer type
auto type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element(); auto type_ptr = input_args[0]->BuildType();
MS_EXCEPTION_IF_NULL(type_ptr);
auto type_tensor_ptr = type_ptr->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(type_tensor_ptr);
auto type = type_tensor_ptr->element();
return std::make_shared<abstract::AbstractTensor>(type, shape); return std::make_shared<abstract::AbstractTensor>(type, shape);
} }
REGISTER_PRIMITIVE_C(kNameGroupConv2DGradInput, GroupConv2DGradInput); REGISTER_PRIMITIVE_C(kNameGroupConv2DGradInput, GroupConv2DGradInput);

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -48,10 +48,12 @@ void LayerNormGrad::set_begin_params_axis(const int64_t begin_params_axis) {
} }
int64_t LayerNormGrad::get_begin_norm_axis() const { int64_t LayerNormGrad::get_begin_norm_axis() const {
auto value_ptr = this->GetAttr(kBeginNormAxis); auto value_ptr = this->GetAttr(kBeginNormAxis);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr); return GetValue<int64_t>(value_ptr);
} }
int64_t LayerNormGrad::get_begin_params_axis() const { int64_t LayerNormGrad::get_begin_params_axis() const {
auto value_ptr = this->GetAttr(kBeginParamsAxis); auto value_ptr = this->GetAttr(kBeginParamsAxis);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr); return GetValue<int64_t>(value_ptr);
} }
REGISTER_PRIMITIVE_EVAL_IMPL(LayerNormGrad, prim::kPrimLayerNormGrad, LayerNormGradInfer, nullptr, true); REGISTER_PRIMITIVE_EVAL_IMPL(LayerNormGrad, prim::kPrimLayerNormGrad, LayerNormGradInfer, nullptr, true);

Some files were not shown because too many files have changed in this diff Show More