forked from mindspore-Ecosystem/mindspore
commit
ad589e6780
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -39,6 +39,7 @@ class RecvGpuKernel : public GpuKernel {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
bool Init(const CNodePtr &kernel_node) override {
|
bool Init(const CNodePtr &kernel_node) override {
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
kernel_node_ = kernel_node;
|
kernel_node_ = kernel_node;
|
||||||
wait_stream_ = reinterpret_cast<cudaStream_t>(GetAttr<uintptr_t>(kernel_node, "wait_event_stream"));
|
wait_stream_ = reinterpret_cast<cudaStream_t>(GetAttr<uintptr_t>(kernel_node, "wait_event_stream"));
|
||||||
wait_event_ = reinterpret_cast<cudaEvent_t>(GetAttr<uintptr_t>(kernel_node, "wait_event"));
|
wait_event_ = reinterpret_cast<cudaEvent_t>(GetAttr<uintptr_t>(kernel_node, "wait_event"));
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -39,6 +39,7 @@ class SendGpuKernel : public GpuKernel {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
bool Init(const CNodePtr &kernel_node) override {
|
bool Init(const CNodePtr &kernel_node) override {
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
kernel_node_ = kernel_node;
|
kernel_node_ = kernel_node;
|
||||||
record_stream_ = reinterpret_cast<cudaStream_t>(GetAttr<uintptr_t>(kernel_node, "record_event_stream"));
|
record_stream_ = reinterpret_cast<cudaStream_t>(GetAttr<uintptr_t>(kernel_node, "record_event_stream"));
|
||||||
record_event_ = reinterpret_cast<cudaEvent_t>(GetAttr<uintptr_t>(kernel_node, "record_event"));
|
record_event_ = reinterpret_cast<cudaEvent_t>(GetAttr<uintptr_t>(kernel_node, "record_event"));
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -34,11 +34,17 @@ const std::vector<size_t> &DatasetInitKernel::GetOutputSizeList() const { return
|
||||||
const std::vector<size_t> &DatasetInitKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
|
const std::vector<size_t> &DatasetInitKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
|
||||||
|
|
||||||
bool DatasetInitKernel::Init(const CNodePtr &kernel_node) {
|
bool DatasetInitKernel::Init(const CNodePtr &kernel_node) {
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
queue_name_ = GetAttr<std::string>(kernel_node, "queue_name");
|
queue_name_ = GetAttr<std::string>(kernel_node, "queue_name");
|
||||||
std::vector<std::vector<int>> shapes;
|
std::vector<std::vector<int>> shapes;
|
||||||
std::vector<TypePtr> types;
|
std::vector<TypePtr> types;
|
||||||
GetShapeAndType(kernel_node, &shapes, &types);
|
GetShapeAndType(kernel_node, &shapes, &types);
|
||||||
|
for (auto item : types) {
|
||||||
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
|
}
|
||||||
|
if (types.size() < shapes.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "types size is less than shapes size.";
|
||||||
|
}
|
||||||
for (size_t i = 0; i < shapes.size(); i++) {
|
for (size_t i = 0; i < shapes.size(); i++) {
|
||||||
int unit = UnitSizeInBytes(types[i]->type_id());
|
int unit = UnitSizeInBytes(types[i]->type_id());
|
||||||
int nums = ElementNums(shapes[i]);
|
int nums = ElementNums(shapes[i]);
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -47,12 +47,18 @@ const std::vector<size_t> &DatasetIteratorKernel::GetOutputSizeList() const { re
|
||||||
const std::vector<size_t> &DatasetIteratorKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
|
const std::vector<size_t> &DatasetIteratorKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
|
||||||
|
|
||||||
bool DatasetIteratorKernel::Init(const CNodePtr &kernel_node) {
|
bool DatasetIteratorKernel::Init(const CNodePtr &kernel_node) {
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
kernel_node_ = kernel_node;
|
kernel_node_ = kernel_node;
|
||||||
queue_name_ = GetAttr<std::string>(kernel_node, "shared_name");
|
queue_name_ = GetAttr<std::string>(kernel_node, "shared_name");
|
||||||
std::vector<std::vector<int>> shapes;
|
std::vector<std::vector<int>> shapes;
|
||||||
std::vector<TypePtr> types;
|
std::vector<TypePtr> types;
|
||||||
GetShapeAndType(kernel_node, &shapes, &types);
|
GetShapeAndType(kernel_node, &shapes, &types);
|
||||||
|
for (auto item : types) {
|
||||||
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
|
}
|
||||||
|
if (types.size() < shapes.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "types size is less than shapes size.";
|
||||||
|
}
|
||||||
for (size_t i = 0; i < shapes.size(); i++) {
|
for (size_t i = 0; i < shapes.size(); i++) {
|
||||||
int unit = UnitSizeInBytes(types[i]->type_id());
|
int unit = UnitSizeInBytes(types[i]->type_id());
|
||||||
int nums = ElementNums(shapes[i]);
|
int nums = ElementNums(shapes[i]);
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -48,6 +48,9 @@ void GetNextProfiling::SaveProfilingData() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
for (uint32_t index = 0; index < queue_size_.size(); index++) {
|
for (uint32_t index = 0; index < queue_size_.size(); index++) {
|
||||||
|
if (index > time_stamp_.size() - 1) {
|
||||||
|
MS_LOG(EXCEPTION) << "index exceeds time_stamp_ size.";
|
||||||
|
}
|
||||||
handle << Name() << " " << time_stamp_[index].first << " " << time_stamp_[index].second << " " << queue_size_[index]
|
handle << Name() << " " << time_stamp_[index].first << " " << time_stamp_[index].second << " " << queue_size_[index]
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -68,6 +68,7 @@ int ElementNums(const std::vector<int> &shape) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void GetShapeAndType(const CNodePtr &kernel_node, std::vector<std::vector<int>> *shapes, std::vector<TypePtr> *types) {
|
void GetShapeAndType(const CNodePtr &kernel_node, std::vector<std::vector<int>> *shapes, std::vector<TypePtr> *types) {
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
MS_EXCEPTION_IF_NULL(shapes);
|
MS_EXCEPTION_IF_NULL(shapes);
|
||||||
MS_EXCEPTION_IF_NULL(types);
|
MS_EXCEPTION_IF_NULL(types);
|
||||||
std::vector<std::vector<int64_t>> shapes_me =
|
std::vector<std::vector<int64_t>> shapes_me =
|
||||||
|
|
|
@ -87,6 +87,7 @@ class PrintGpuKernel : public GpuKernel {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Init(const CNodePtr &kernel_node) override {
|
bool Init(const CNodePtr &kernel_node) override {
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
kernel_node_ = kernel_node;
|
kernel_node_ = kernel_node;
|
||||||
if (AnfAlgo::HasNodeAttr("string_pos", kernel_node)) {
|
if (AnfAlgo::HasNodeAttr("string_pos", kernel_node)) {
|
||||||
string_value_ = GetAttr<std::vector<std::string>>(kernel_node, "string_value");
|
string_value_ = GetAttr<std::vector<std::string>>(kernel_node, "string_value");
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -77,6 +77,7 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Init(const CNodePtr &kernel_node) override {
|
bool Init(const CNodePtr &kernel_node) override {
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
kernel_node_ = kernel_node;
|
kernel_node_ = kernel_node;
|
||||||
nccl_data_type_ = nccl_dtype(AnfAlgo::GetInputDeviceDataType(kernel_node, 0));
|
nccl_data_type_ = nccl_dtype(AnfAlgo::GetInputDeviceDataType(kernel_node, 0));
|
||||||
InferCommType(kernel_node);
|
InferCommType(kernel_node);
|
||||||
|
|
|
@ -49,6 +49,7 @@ class AssignGpuKernel : public GpuKernel {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Init(const CNodePtr &kernel_node) override {
|
bool Init(const CNodePtr &kernel_node) override {
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
kernel_node_ = kernel_node;
|
kernel_node_ = kernel_node;
|
||||||
if (!CheckParam(kernel_node)) {
|
if (!CheckParam(kernel_node)) {
|
||||||
return false;
|
return false;
|
||||||
|
@ -71,6 +72,7 @@ class AssignGpuKernel : public GpuKernel {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool CheckParam(const CNodePtr &kernel_node) {
|
bool CheckParam(const CNodePtr &kernel_node) {
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||||
if (input_num != 2) {
|
if (input_num != 2) {
|
||||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but AssignGpuKernel needs 2 output.";
|
MS_LOG(ERROR) << "Input number is " << input_num << ", but AssignGpuKernel needs 2 output.";
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -62,6 +62,7 @@ class BoundingBoxDecodeGpuKernel : public GpuKernel {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Init(const CNodePtr &kernel_node) override {
|
bool Init(const CNodePtr &kernel_node) override {
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||||
if (input_num != 2) {
|
if (input_num != 2) {
|
||||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but BoundingBoxDecode needs 2 inputs.";
|
MS_LOG(ERROR) << "Input number is " << input_num << ", but BoundingBoxDecode needs 2 inputs.";
|
||||||
|
@ -89,10 +90,11 @@ class BoundingBoxDecodeGpuKernel : public GpuKernel {
|
||||||
InitSizeLists();
|
InitSizeLists();
|
||||||
|
|
||||||
const size_t coordinate_size = 4;
|
const size_t coordinate_size = 4;
|
||||||
if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<ValueTuple>() ||
|
auto means = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means");
|
||||||
AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<ValueList>()) {
|
MS_EXCEPTION_IF_NULL(means);
|
||||||
|
if (means->isa<ValueTuple>() || means->isa<ValueList>()) {
|
||||||
means_ = GetAttr<std::vector<float>>(kernel_node, "means");
|
means_ = GetAttr<std::vector<float>>(kernel_node, "means");
|
||||||
} else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<FloatImm>()) {
|
} else if (means->isa<FloatImm>()) {
|
||||||
float mean = GetAttr<float>(kernel_node, "means");
|
float mean = GetAttr<float>(kernel_node, "means");
|
||||||
for (size_t i = 0; i < coordinate_size; i++) {
|
for (size_t i = 0; i < coordinate_size; i++) {
|
||||||
means_.emplace_back(mean);
|
means_.emplace_back(mean);
|
||||||
|
@ -101,10 +103,11 @@ class BoundingBoxDecodeGpuKernel : public GpuKernel {
|
||||||
MS_LOG(EXCEPTION) << "Attribute means type is invalid.";
|
MS_LOG(EXCEPTION) << "Attribute means type is invalid.";
|
||||||
}
|
}
|
||||||
|
|
||||||
if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<ValueTuple>() ||
|
auto stds = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds");
|
||||||
AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<ValueList>()) {
|
MS_EXCEPTION_IF_NULL(stds);
|
||||||
|
if (stds->isa<ValueTuple>() || stds->isa<ValueList>()) {
|
||||||
stds_ = GetAttr<std::vector<float>>(kernel_node, "stds");
|
stds_ = GetAttr<std::vector<float>>(kernel_node, "stds");
|
||||||
} else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<FloatImm>()) {
|
} else if (stds->isa<FloatImm>()) {
|
||||||
float std = GetAttr<float>(kernel_node, "stds");
|
float std = GetAttr<float>(kernel_node, "stds");
|
||||||
for (size_t i = 0; i < coordinate_size; i++) {
|
for (size_t i = 0; i < coordinate_size; i++) {
|
||||||
stds_.emplace_back(std);
|
stds_.emplace_back(std);
|
||||||
|
|
|
@ -61,6 +61,7 @@ class BoundingBoxEncodeGpuKernel : public GpuKernel {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Init(const CNodePtr &kernel_node) override {
|
bool Init(const CNodePtr &kernel_node) override {
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||||
if (input_num != 2) {
|
if (input_num != 2) {
|
||||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but BoundingBoxEncode needs 2 inputs.";
|
MS_LOG(ERROR) << "Input number is " << input_num << ", but BoundingBoxEncode needs 2 inputs.";
|
||||||
|
@ -88,10 +89,11 @@ class BoundingBoxEncodeGpuKernel : public GpuKernel {
|
||||||
InitSizeLists();
|
InitSizeLists();
|
||||||
|
|
||||||
const size_t coordinate_size = 4;
|
const size_t coordinate_size = 4;
|
||||||
if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<ValueTuple>() ||
|
auto means = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means");
|
||||||
AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<ValueList>()) {
|
MS_EXCEPTION_IF_NULL(means);
|
||||||
|
if (means->isa<ValueTuple>() || means->isa<ValueList>()) {
|
||||||
means_ = GetAttr<std::vector<float>>(kernel_node, "means");
|
means_ = GetAttr<std::vector<float>>(kernel_node, "means");
|
||||||
} else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<FloatImm>()) {
|
} else if (means->isa<FloatImm>()) {
|
||||||
float mean = GetAttr<float>(kernel_node, "means");
|
float mean = GetAttr<float>(kernel_node, "means");
|
||||||
for (size_t i = 0; i < coordinate_size; i++) {
|
for (size_t i = 0; i < coordinate_size; i++) {
|
||||||
means_.emplace_back(mean);
|
means_.emplace_back(mean);
|
||||||
|
@ -99,11 +101,11 @@ class BoundingBoxEncodeGpuKernel : public GpuKernel {
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(EXCEPTION) << "Attribute means type is invalid.";
|
MS_LOG(EXCEPTION) << "Attribute means type is invalid.";
|
||||||
}
|
}
|
||||||
|
auto stds = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds");
|
||||||
if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<ValueTuple>() ||
|
MS_EXCEPTION_IF_NULL(stds);
|
||||||
AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<ValueList>()) {
|
if (stds->isa<ValueTuple>() || stds->isa<ValueList>()) {
|
||||||
stds_ = GetAttr<std::vector<float>>(kernel_node, "stds");
|
stds_ = GetAttr<std::vector<float>>(kernel_node, "stds");
|
||||||
} else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<FloatImm>()) {
|
} else if (stds->isa<FloatImm>()) {
|
||||||
float std = GetAttr<float>(kernel_node, "stds");
|
float std = GetAttr<float>(kernel_node, "stds");
|
||||||
for (size_t i = 0; i < coordinate_size; i++) {
|
for (size_t i = 0; i < coordinate_size; i++) {
|
||||||
stds_.emplace_back(std);
|
stds_.emplace_back(std);
|
||||||
|
|
|
@ -55,6 +55,7 @@ class CheckValidGpuKernel : public GpuKernel {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Init(const CNodePtr &kernel_node) override {
|
bool Init(const CNodePtr &kernel_node) override {
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||||
if (input_num != 2) {
|
if (input_num != 2) {
|
||||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but CheckValid needs 2 inputs.";
|
MS_LOG(ERROR) << "Input number is " << input_num << ", but CheckValid needs 2 inputs.";
|
||||||
|
|
|
@ -59,6 +59,7 @@ class GpuConvertToDynamicShapeGpuKernel : public GpuKernel {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Init(const CNodePtr &kernel_node) override {
|
bool Init(const CNodePtr &kernel_node) override {
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
kernel_node_ = kernel_node;
|
kernel_node_ = kernel_node;
|
||||||
size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node);
|
size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||||
if (input_count != 1) {
|
if (input_count != 1) {
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -59,6 +59,7 @@ class IOUGpuKernel : public GpuKernel {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Init(const CNodePtr &kernel_node) override {
|
bool Init(const CNodePtr &kernel_node) override {
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||||
if (input_num != 2) {
|
if (input_num != 2) {
|
||||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but IOU needs 2 inputs.";
|
MS_LOG(ERROR) << "Input number is " << input_num << ", but IOU needs 2 inputs.";
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -89,6 +89,7 @@ class RandomCategoricalGpuKernel : public GpuKernel {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Init(const CNodePtr &kernel_node) override {
|
bool Init(const CNodePtr &kernel_node) override {
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
kernel_node_ = kernel_node;
|
kernel_node_ = kernel_node;
|
||||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||||
if (input_num != 3) {
|
if (input_num != 3) {
|
||||||
|
|
|
@ -70,6 +70,7 @@ class RandomChoiceWithMaskGpuKernel : public GpuKernel {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Init(const CNodePtr &kernel_node) override {
|
bool Init(const CNodePtr &kernel_node) override {
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
uint32_t time_interval = std::chrono::system_clock::now().time_since_epoch().count();
|
uint32_t time_interval = std::chrono::system_clock::now().time_since_epoch().count();
|
||||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||||
if (input_num != 1) {
|
if (input_num != 1) {
|
||||||
|
|
|
@ -76,6 +76,7 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Init(const CNodePtr &kernel_node) override {
|
bool Init(const CNodePtr &kernel_node) override {
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
kernel_node_ = kernel_node;
|
kernel_node_ = kernel_node;
|
||||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||||
if (input_num != 1) {
|
if (input_num != 1) {
|
||||||
|
@ -173,6 +174,7 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel {
|
||||||
MS_LOG(EXCEPTION) << "range_max_ failed to cast";
|
MS_LOG(EXCEPTION) << "range_max_ failed to cast";
|
||||||
}
|
}
|
||||||
range = static_cast<S>(range_max_);
|
range = static_cast<S>(range_max_);
|
||||||
|
MS_EXCEPTION_IF_ZERO("range", range);
|
||||||
return static_cast<S>(1.0f / range);
|
return static_cast<S>(1.0f / range);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -42,6 +42,7 @@ void AdderFusion::set_activation_type(const ActivationType activation_type) {
|
||||||
|
|
||||||
ActivationType AdderFusion::get_activation_type() const {
|
ActivationType AdderFusion::get_activation_type() const {
|
||||||
auto value_ptr = GetAttr(kActivationType);
|
auto value_ptr = GetAttr(kActivationType);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_C(kNameAdderFusion, AdderFusion);
|
REGISTER_PRIMITIVE_C(kNameAdderFusion, AdderFusion);
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -31,9 +31,21 @@ void ArgMaxFusion::set_out_max_value(const bool out_max_value) {
|
||||||
}
|
}
|
||||||
void ArgMaxFusion::set_top_k(const int64_t top_k) { (void)this->AddAttr(kTopK, MakeValue(top_k)); }
|
void ArgMaxFusion::set_top_k(const int64_t top_k) { (void)this->AddAttr(kTopK, MakeValue(top_k)); }
|
||||||
|
|
||||||
bool ArgMaxFusion::get_keep_dims() const { return GetValue<bool>(GetAttr(kKeepDims)); }
|
bool ArgMaxFusion::get_keep_dims() const {
|
||||||
bool ArgMaxFusion::get_out_max_value() const { return GetValue<bool>(GetAttr(kOutMaxValue)); }
|
auto keep_dims = GetAttr(kKeepDims);
|
||||||
int64_t ArgMaxFusion::get_top_k() const { return GetValue<int64_t>(GetAttr(kTopK)); }
|
MS_EXCEPTION_IF_NULL(keep_dims);
|
||||||
|
return GetValue<bool>(keep_dims);
|
||||||
|
}
|
||||||
|
bool ArgMaxFusion::get_out_max_value() const {
|
||||||
|
auto out_maxv = GetAttr(kOutMaxValue);
|
||||||
|
MS_EXCEPTION_IF_NULL(out_maxv);
|
||||||
|
return GetValue<bool>(out_maxv);
|
||||||
|
}
|
||||||
|
int64_t ArgMaxFusion::get_top_k() const {
|
||||||
|
auto topk = GetAttr(kTopK);
|
||||||
|
MS_EXCEPTION_IF_NULL(topk);
|
||||||
|
return GetValue<int64_t>(topk);
|
||||||
|
}
|
||||||
|
|
||||||
REGISTER_PRIMITIVE_C(kNameArgMaxFusion, ArgMaxFusion);
|
REGISTER_PRIMITIVE_C(kNameArgMaxFusion, ArgMaxFusion);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -29,14 +29,21 @@ void ArgMinFusion::set_keep_dims(const bool keep_dims) { (void)this->AddAttr(kKe
|
||||||
void ArgMinFusion::set_out_max_value(bool out_max_value) { (void)AddAttr(kOutMaxValue, MakeValue(out_max_value)); }
|
void ArgMinFusion::set_out_max_value(bool out_max_value) { (void)AddAttr(kOutMaxValue, MakeValue(out_max_value)); }
|
||||||
void ArgMinFusion::set_top_k(int64_t top_k) { (void)this->AddAttr(kTopK, MakeValue(top_k)); }
|
void ArgMinFusion::set_top_k(int64_t top_k) { (void)this->AddAttr(kTopK, MakeValue(top_k)); }
|
||||||
|
|
||||||
bool ArgMinFusion::get_keep_dims() const { return GetValue<bool>(GetAttr(kKeepDims)); }
|
bool ArgMinFusion::get_keep_dims() const {
|
||||||
|
auto keep_dims = GetAttr(kKeepDims);
|
||||||
|
MS_EXCEPTION_IF_NULL(keep_dims);
|
||||||
|
return GetValue<bool>(keep_dims);
|
||||||
|
}
|
||||||
|
|
||||||
bool ArgMinFusion::get_out_max_value() const {
|
bool ArgMinFusion::get_out_max_value() const {
|
||||||
auto value_ptr = GetAttr(kOutMaxValue);
|
auto value_ptr = GetAttr(kOutMaxValue);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<bool>(value_ptr);
|
return GetValue<bool>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t ArgMinFusion::get_top_k() const {
|
int64_t ArgMinFusion::get_top_k() const {
|
||||||
auto value_ptr = GetAttr(kTopK);
|
auto value_ptr = GetAttr(kTopK);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<int64_t>(value_ptr);
|
return GetValue<int64_t>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -40,17 +40,22 @@ void AvgPoolFusion::set_activation_type(ActivationType activation_type) {
|
||||||
|
|
||||||
bool AvgPoolFusion::get_global() const {
|
bool AvgPoolFusion::get_global() const {
|
||||||
auto value_ptr = GetAttr(kGlobal);
|
auto value_ptr = GetAttr(kGlobal);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<bool>(value_ptr);
|
return GetValue<bool>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
ActivationType AvgPoolFusion::get_activation_type() const {
|
ActivationType AvgPoolFusion::get_activation_type() const {
|
||||||
auto value_ptr = GetAttr(kActivationType);
|
auto value_ptr = GetAttr(kActivationType);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
for (auto item : input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
|
}
|
||||||
auto op_name = primitive->name();
|
auto op_name = primitive->name();
|
||||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||||
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
|
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
|
||||||
|
@ -66,6 +71,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
||||||
auto in_w = in_shape[3];
|
auto in_w = in_shape[3];
|
||||||
|
|
||||||
auto strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides));
|
auto strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides));
|
||||||
|
(void)CheckAndConvertUtils::CheckPositiveVector(kStride, strides, op_name);
|
||||||
auto kernel_h = kernel_size[2];
|
auto kernel_h = kernel_size[2];
|
||||||
auto kernel_w = kernel_size[3];
|
auto kernel_w = kernel_size[3];
|
||||||
auto stride_h = strides[2];
|
auto stride_h = strides[2];
|
||||||
|
@ -90,8 +96,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
||||||
}
|
}
|
||||||
|
|
||||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr arg) { return arg == nullptr; })) {
|
for (auto item : input_args) {
|
||||||
MS_LOG(EXCEPTION) << "nullptr";
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
}
|
}
|
||||||
return input_args[0]->BuildType();
|
return input_args[0]->BuildType();
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -52,11 +52,13 @@ void Conv2DBackpropFilterFusion::set_in_channel(const int64_t in_channel) {
|
||||||
|
|
||||||
ActivationType Conv2DBackpropFilterFusion::get_activation_type() const {
|
ActivationType Conv2DBackpropFilterFusion::get_activation_type() const {
|
||||||
auto value_ptr = GetAttr(kActivationType);
|
auto value_ptr = GetAttr(kActivationType);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t Conv2DBackpropFilterFusion::get_in_channel() const {
|
int64_t Conv2DBackpropFilterFusion::get_in_channel() const {
|
||||||
auto value_ptr = GetAttr(kInChannel);
|
auto value_ptr = GetAttr(kInChannel);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<int64_t>(value_ptr);
|
return GetValue<int64_t>(value_ptr);
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_C(kNameConv2DBackpropFilterFusion, Conv2DBackpropFilterFusion);
|
REGISTER_PRIMITIVE_C(kNameConv2DBackpropFilterFusion, Conv2DBackpropFilterFusion);
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -51,11 +51,13 @@ void Conv2DBackpropInputFusion::set_activation_type(const ActivationType &activa
|
||||||
}
|
}
|
||||||
int64_t Conv2DBackpropInputFusion::get_in_channel() const {
|
int64_t Conv2DBackpropInputFusion::get_in_channel() const {
|
||||||
auto value_ptr = GetAttr(kInChannel);
|
auto value_ptr = GetAttr(kInChannel);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<int64_t>(value_ptr);
|
return GetValue<int64_t>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
ActivationType Conv2DBackpropInputFusion::get_activation_type() const {
|
ActivationType Conv2DBackpropInputFusion::get_activation_type() const {
|
||||||
auto value_ptr = GetAttr(kActivationType);
|
auto value_ptr = GetAttr(kActivationType);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_C(kNameConv2DBackpropInputFusion, Conv2DBackpropInputFusion);
|
REGISTER_PRIMITIVE_C(kNameConv2DBackpropInputFusion, Conv2DBackpropInputFusion);
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -48,14 +48,17 @@ void Conv2DFusion::set_activation_type(const ActivationType &activation_type) {
|
||||||
}
|
}
|
||||||
int64_t Conv2DFusion::get_in_channel() const {
|
int64_t Conv2DFusion::get_in_channel() const {
|
||||||
auto value_ptr = GetAttr(kInChannel);
|
auto value_ptr = GetAttr(kInChannel);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<int64_t>(value_ptr);
|
return GetValue<int64_t>(value_ptr);
|
||||||
}
|
}
|
||||||
std::vector<int64_t> Conv2DFusion::get_pad_list() const {
|
std::vector<int64_t> Conv2DFusion::get_pad_list() const {
|
||||||
auto value_ptr = GetAttr(kPadList);
|
auto value_ptr = GetAttr(kPadList);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||||
}
|
}
|
||||||
ActivationType Conv2DFusion::get_activation_type() const {
|
ActivationType Conv2DFusion::get_activation_type() const {
|
||||||
auto value_ptr = GetAttr(kActivationType);
|
auto value_ptr = GetAttr(kActivationType);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_C(kNameConv2DFusion, Conv2DFusion);
|
REGISTER_PRIMITIVE_C(kNameConv2DFusion, Conv2DFusion);
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -71,11 +71,13 @@ void Conv2dTransposeFusion::set_activation_type(ActivationType activation_type)
|
||||||
|
|
||||||
std::vector<int64_t> Conv2dTransposeFusion::get_output_paddings() const {
|
std::vector<int64_t> Conv2dTransposeFusion::get_output_paddings() const {
|
||||||
auto value_ptr = GetAttr(kOutputPaddings);
|
auto value_ptr = GetAttr(kOutputPaddings);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
ActivationType Conv2dTransposeFusion::get_activation_type() const {
|
ActivationType Conv2dTransposeFusion::get_activation_type() const {
|
||||||
auto value_ptr = GetAttr(kActivationType);
|
auto value_ptr = GetAttr(kActivationType);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -29,6 +29,7 @@ void DivFusion::set_activation_type(const ActivationType &activation_type) {
|
||||||
|
|
||||||
ActivationType DivFusion::get_activation_type() const {
|
ActivationType DivFusion::get_activation_type() const {
|
||||||
auto value_ptr = GetAttr(kActivationType);
|
auto value_ptr = GetAttr(kActivationType);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_C(kNameDivFusion, DivFusion);
|
REGISTER_PRIMITIVE_C(kNameDivFusion, DivFusion);
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -23,6 +23,7 @@ namespace ops {
|
||||||
void EmbeddingLookupFusion::set_max_norm(const float max_norm) { (void)this->AddAttr(kMaxNorm, MakeValue(max_norm)); }
|
void EmbeddingLookupFusion::set_max_norm(const float max_norm) { (void)this->AddAttr(kMaxNorm, MakeValue(max_norm)); }
|
||||||
float EmbeddingLookupFusion::get_max_norm() const {
|
float EmbeddingLookupFusion::get_max_norm() const {
|
||||||
auto value_ptr = GetAttr(kMaxNorm);
|
auto value_ptr = GetAttr(kMaxNorm);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<float>(value_ptr);
|
return GetValue<float>(value_ptr);
|
||||||
}
|
}
|
||||||
void EmbeddingLookupFusion::Init(const float max_norm) { this->set_max_norm(max_norm); }
|
void EmbeddingLookupFusion::Init(const float max_norm) { this->set_max_norm(max_norm); }
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -37,14 +37,17 @@ void ExpFusion::set_shift(const float shift) { (void)this->AddAttr(kShift, MakeV
|
||||||
|
|
||||||
float ExpFusion::get_base() const {
|
float ExpFusion::get_base() const {
|
||||||
auto value_ptr = GetAttr(kBase);
|
auto value_ptr = GetAttr(kBase);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<float>(value_ptr);
|
return GetValue<float>(value_ptr);
|
||||||
}
|
}
|
||||||
float ExpFusion::get_scale() const {
|
float ExpFusion::get_scale() const {
|
||||||
auto value_ptr = GetAttr(kScale);
|
auto value_ptr = GetAttr(kScale);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<float>(value_ptr);
|
return GetValue<float>(value_ptr);
|
||||||
}
|
}
|
||||||
float ExpFusion::get_shift() const {
|
float ExpFusion::get_shift() const {
|
||||||
auto value_ptr = GetAttr(kShift);
|
auto value_ptr = GetAttr(kShift);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<float>(value_ptr);
|
return GetValue<float>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -22,13 +22,25 @@ namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
void FullConnection::set_has_bias(const bool has_bias) { (void)this->AddAttr(kHasBias, MakeValue(has_bias)); }
|
void FullConnection::set_has_bias(const bool has_bias) { (void)this->AddAttr(kHasBias, MakeValue(has_bias)); }
|
||||||
|
|
||||||
bool FullConnection::get_has_bias() const { return GetValue<bool>(GetAttr(kHasBias)); }
|
bool FullConnection::get_has_bias() const {
|
||||||
|
auto value_ptr = GetAttr(kHasBias);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
|
return GetValue<bool>(value_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
void FullConnection::set_axis(const int64_t axis) { (void)this->AddAttr(kAxis, MakeValue(axis)); }
|
void FullConnection::set_axis(const int64_t axis) { (void)this->AddAttr(kAxis, MakeValue(axis)); }
|
||||||
int64_t FullConnection::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); }
|
int64_t FullConnection::get_axis() const {
|
||||||
|
auto value_ptr = GetAttr(kAxis);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
|
return GetValue<int64_t>(value_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
void FullConnection::set_use_axis(const bool use_axis) { (void)this->AddAttr(kUseAxis, MakeValue(use_axis)); }
|
void FullConnection::set_use_axis(const bool use_axis) { (void)this->AddAttr(kUseAxis, MakeValue(use_axis)); }
|
||||||
bool FullConnection::get_use_axis() const { return GetValue<bool>(GetAttr(kUseAxis)); }
|
bool FullConnection::get_use_axis() const {
|
||||||
|
auto value_ptr = GetAttr(kUseAxis);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
|
return GetValue<bool>(value_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
void FullConnection::set_activation_type(const ActivationType &activation_type) {
|
void FullConnection::set_activation_type(const ActivationType &activation_type) {
|
||||||
int64_t swi = activation_type;
|
int64_t swi = activation_type;
|
||||||
|
@ -36,6 +48,7 @@ void FullConnection::set_activation_type(const ActivationType &activation_type)
|
||||||
}
|
}
|
||||||
ActivationType FullConnection::get_activation_type() const {
|
ActivationType FullConnection::get_activation_type() const {
|
||||||
auto value_ptr = GetAttr(kActivationType);
|
auto value_ptr = GetAttr(kActivationType);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||||
}
|
}
|
||||||
void FullConnection::Init(const bool has_bias, const int64_t axis, const bool use_axis,
|
void FullConnection::Init(const bool has_bias, const int64_t axis, const bool use_axis,
|
||||||
|
@ -57,14 +70,16 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P
|
||||||
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input1->BuildShape())[kShape];
|
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input1->BuildShape())[kShape];
|
||||||
auto prim_axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
|
auto prim_axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
|
||||||
auto has_bias = GetValue<bool>(primitive->GetAttr(kHasBias));
|
auto has_bias = GetValue<bool>(primitive->GetAttr(kHasBias));
|
||||||
|
const int64_t input_num_bias = 3;
|
||||||
|
const int64_t input_num = 2;
|
||||||
if (has_bias) {
|
if (has_bias) {
|
||||||
(void)CheckAndConvertUtils::CheckInteger("input_args.size()", input_args.size(), kEqual, 3, prim_name);
|
(void)CheckAndConvertUtils::CheckInteger("input_args.size()", input_args.size(), kEqual, input_num_bias, prim_name);
|
||||||
} else {
|
} else {
|
||||||
(void)CheckAndConvertUtils::CheckInteger("input_args.size()", input_args.size(), kEqual, 2, prim_name);
|
(void)CheckAndConvertUtils::CheckInteger("input_args.size()", input_args.size(), kEqual, input_num, prim_name);
|
||||||
}
|
}
|
||||||
auto use_axis = GetValue<bool>(primitive->GetAttr(kUseAxis));
|
auto use_axis = GetValue<bool>(primitive->GetAttr(kUseAxis));
|
||||||
if (use_axis && (prim_axis < 1 || prim_axis > (int64_t)input0_shape.size())) {
|
if (use_axis && (prim_axis < 1 || prim_axis > (int64_t)input0_shape.size())) {
|
||||||
MS_EXCEPTION(ValueError) << "Full Connection axis invalid";
|
MS_EXCEPTION(ValueError) << "Full Connection axis is invalid";
|
||||||
}
|
}
|
||||||
int64_t new_k = 1;
|
int64_t new_k = 1;
|
||||||
if (use_axis) {
|
if (use_axis) {
|
||||||
|
@ -72,7 +87,7 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P
|
||||||
new_k *= input0_shape[t];
|
new_k *= input0_shape[t];
|
||||||
}
|
}
|
||||||
if (new_k != input1_shape[1]) {
|
if (new_k != input1_shape[1]) {
|
||||||
MS_EXCEPTION(ValueError) << "Input1 size invalid";
|
MS_EXCEPTION(ValueError) << "Input1 size is invalid";
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
new_k = input1_shape[1];
|
new_k = input1_shape[1];
|
||||||
|
@ -80,7 +95,7 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P
|
||||||
if (has_bias) {
|
if (has_bias) {
|
||||||
auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||||
if (input2_shape[0] != input1_shape[0]) {
|
if (input2_shape[0] != input1_shape[0]) {
|
||||||
MS_EXCEPTION(ValueError) << "Bias size invalid";
|
MS_EXCEPTION(ValueError) << "Bias size is invalid";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::vector<int64_t> out_shape = {(int64_t)input0_shape.size()};
|
std::vector<int64_t> out_shape = {(int64_t)input0_shape.size()};
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -34,6 +34,7 @@ void L2NormalizeFusion::set_activation_type(const ActivationType &activation_typ
|
||||||
|
|
||||||
ActivationType L2NormalizeFusion::get_activation_type() const {
|
ActivationType L2NormalizeFusion::get_activation_type() const {
|
||||||
auto value_ptr = GetAttr(kActivationType);
|
auto value_ptr = GetAttr(kActivationType);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_C(kNameL2NormalizeFusion, L2NormalizeFusion);
|
REGISTER_PRIMITIVE_C(kNameL2NormalizeFusion, L2NormalizeFusion);
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -32,6 +32,7 @@ void LayerNormFusion::set_elementwise_affine(const bool elementwise_affine) {
|
||||||
|
|
||||||
bool LayerNormFusion::get_elementwise_affine() const {
|
bool LayerNormFusion::get_elementwise_affine() const {
|
||||||
auto value_ptr = GetAttr(kElementwiseAffine);
|
auto value_ptr = GetAttr(kElementwiseAffine);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<bool>(value_ptr);
|
return GetValue<bool>(value_ptr);
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_C(kNameLayerNormFusion, LayerNormFusion);
|
REGISTER_PRIMITIVE_C(kNameLayerNormFusion, LayerNormFusion);
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -40,24 +40,28 @@ void MaxPoolFusion::set_activation_type(ActivationType activation_type) {
|
||||||
|
|
||||||
bool MaxPoolFusion::get_global() const {
|
bool MaxPoolFusion::get_global() const {
|
||||||
auto value_ptr = GetAttr(kGlobal);
|
auto value_ptr = GetAttr(kGlobal);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<bool>(value_ptr);
|
return GetValue<bool>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
ActivationType MaxPoolFusion::get_activation_type() const {
|
ActivationType MaxPoolFusion::get_activation_type() const {
|
||||||
auto value_ptr = GetAttr(kActivationType);
|
auto value_ptr = GetAttr(kActivationType);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||||
auto op_name = primitive->name();
|
auto op_name = primitive->name();
|
||||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||||
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
|
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
|
||||||
if (format == NHWC) {
|
if (format == NHWC) {
|
||||||
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
|
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
|
||||||
}
|
}
|
||||||
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, 4, op_name);
|
const int64_t in_shape_size = 4;
|
||||||
|
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, in_shape_size, op_name);
|
||||||
auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
|
auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
|
||||||
auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode)));
|
auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode)));
|
||||||
auto batch = in_shape[0];
|
auto batch = in_shape[0];
|
||||||
|
@ -84,21 +88,22 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
||||||
out_shape = {batch, out_h, out_w, channel};
|
out_shape = {batch, out_h, out_w, channel};
|
||||||
}
|
}
|
||||||
if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {
|
if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {
|
||||||
MS_LOG(EXCEPTION) << "Kernel size is not valid.";
|
MS_LOG(EXCEPTION) << "Kernel size is invalid.";
|
||||||
}
|
}
|
||||||
return std::make_shared<abstract::Shape>(out_shape);
|
return std::make_shared<abstract::Shape>(out_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) {
|
|
||||||
MS_LOG(EXCEPTION) << "nullptr";
|
|
||||||
}
|
|
||||||
return input_args[0]->BuildType();
|
return input_args[0]->BuildType();
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
AbstractBasePtr MaxPoolFusionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr MaxPoolFusionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
for (auto item : input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
|
}
|
||||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||||
InferShape(primitive, input_args)->shape());
|
InferShape(primitive, input_args)->shape());
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -29,6 +29,7 @@ void MulFusion::set_activation_type(const ActivationType &activation_type) {
|
||||||
}
|
}
|
||||||
ActivationType MulFusion::get_activation_type() const {
|
ActivationType MulFusion::get_activation_type() const {
|
||||||
auto value_ptr = GetAttr(kActivationType);
|
auto value_ptr = GetAttr(kActivationType);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||||
}
|
}
|
||||||
void MulFusion::Init(const ActivationType &activation_type) { this->set_activation_type(activation_type); }
|
void MulFusion::Init(const ActivationType &activation_type) { this->set_activation_type(activation_type); }
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -39,10 +39,12 @@ void PadFusion::set_constant_value(const float constant_value) {
|
||||||
|
|
||||||
PaddingMode PadFusion::get_padding_mode() const {
|
PaddingMode PadFusion::get_padding_mode() const {
|
||||||
auto value_ptr = GetAttr(kPaddingMode);
|
auto value_ptr = GetAttr(kPaddingMode);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return PaddingMode(GetValue<int64_t>(value_ptr));
|
return PaddingMode(GetValue<int64_t>(value_ptr));
|
||||||
}
|
}
|
||||||
float PadFusion::get_constant_value() const {
|
float PadFusion::get_constant_value() const {
|
||||||
auto value_ptr = GetAttr(kConstantValue);
|
auto value_ptr = GetAttr(kConstantValue);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<float>(value_ptr);
|
return GetValue<float>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -25,6 +25,7 @@ void PartialFusion::set_sub_graph_index(const int64_t sub_graph_index) {
|
||||||
}
|
}
|
||||||
int64_t PartialFusion::get_sub_graph_index() const {
|
int64_t PartialFusion::get_sub_graph_index() const {
|
||||||
auto value_ptr = GetAttr(kSubGraphIndex);
|
auto value_ptr = GetAttr(kSubGraphIndex);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<int64_t>(value_ptr);
|
return GetValue<int64_t>(value_ptr);
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_C(kNamePartialFusion, PartialFusion);
|
REGISTER_PRIMITIVE_C(kNamePartialFusion, PartialFusion);
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -42,6 +42,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
||||||
}
|
}
|
||||||
|
|
||||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
for (const auto &item : input_args) {
|
for (const auto &item : input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(item);
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
}
|
}
|
||||||
|
@ -54,6 +55,10 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
||||||
|
|
||||||
AbstractBasePtr PowFusionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr PowFusionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
const int64_t input_num = 2;
|
||||||
|
CheckAndConvertUtils::CheckInteger("PowFusion infer", SizeToLong(input_args.size()), kGreaterEqual, input_num,
|
||||||
|
primitive->name());
|
||||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||||
InferShape(primitive, input_args)->shape());
|
InferShape(primitive, input_args)->shape());
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -36,10 +36,12 @@ void PReLUFusion::set_slope(const std::vector<float> &slope) { (void)this->AddAt
|
||||||
|
|
||||||
bool PReLUFusion::get_channel_shared() const {
|
bool PReLUFusion::get_channel_shared() const {
|
||||||
auto value_ptr = GetAttr(kChannelShared);
|
auto value_ptr = GetAttr(kChannelShared);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<bool>(value_ptr);
|
return GetValue<bool>(value_ptr);
|
||||||
}
|
}
|
||||||
std::vector<float> PReLUFusion::get_slope() const {
|
std::vector<float> PReLUFusion::get_slope() const {
|
||||||
auto value_ptr = GetAttr(kSlope);
|
auto value_ptr = GetAttr(kSlope);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<std::vector<float>>(value_ptr);
|
return GetValue<std::vector<float>>(value_ptr);
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_C(kNamePReLUFusion, PReLUFusion);
|
REGISTER_PRIMITIVE_C(kNamePReLUFusion, PReLUFusion);
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -41,21 +41,25 @@ void ReduceFusion::set_coeff(const float coeff) { (void)this->AddAttr(kCoeff, Ma
|
||||||
|
|
||||||
bool ReduceFusion::get_keep_dims() const {
|
bool ReduceFusion::get_keep_dims() const {
|
||||||
auto value_ptr = GetAttr(kKeepDims);
|
auto value_ptr = GetAttr(kKeepDims);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<bool>(value_ptr);
|
return GetValue<bool>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
ReduceMode ReduceFusion::get_mode() const {
|
ReduceMode ReduceFusion::get_mode() const {
|
||||||
auto value_ptr = GetAttr(kMode);
|
auto value_ptr = GetAttr(kMode);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return ReduceMode(GetValue<int64_t>(value_ptr));
|
return ReduceMode(GetValue<int64_t>(value_ptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ReduceFusion::get_reduce_to_end() const {
|
bool ReduceFusion::get_reduce_to_end() const {
|
||||||
auto value_ptr = GetAttr(kReduceToEnd);
|
auto value_ptr = GetAttr(kReduceToEnd);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<bool>(value_ptr);
|
return GetValue<bool>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
float ReduceFusion::get_coeff() const {
|
float ReduceFusion::get_coeff() const {
|
||||||
auto value_ptr = GetAttr(kCoeff);
|
auto value_ptr = GetAttr(kCoeff);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<float>(value_ptr);
|
return GetValue<float>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -32,6 +32,7 @@ void ScaleFusion::set_activation_type(const ActivationType &activation_type) {
|
||||||
|
|
||||||
ActivationType ScaleFusion::get_activation_type() const {
|
ActivationType ScaleFusion::get_activation_type() const {
|
||||||
auto value_ptr = GetAttr(kActivationType);
|
auto value_ptr = GetAttr(kActivationType);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_C(kNameScaleFusion, ScaleFusion);
|
REGISTER_PRIMITIVE_C(kNameScaleFusion, ScaleFusion);
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -26,6 +26,7 @@ void SliceFusion::set_axes(const std::vector<int64_t> &axes) { (void)this->AddAt
|
||||||
|
|
||||||
std::vector<int64_t> SliceFusion::get_axes() const {
|
std::vector<int64_t> SliceFusion::get_axes() const {
|
||||||
auto value_ptr = GetAttr(kAxes);
|
auto value_ptr = GetAttr(kAxes);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -39,6 +40,8 @@ AbstractBasePtr SliceFusionInfer(const abstract::AnalysisEnginePtr &, const Prim
|
||||||
auto size_v = input_args[2]->BuildValue();
|
auto size_v = input_args[2]->BuildValue();
|
||||||
auto x_type = input_args[0]->BuildType();
|
auto x_type = input_args[0]->BuildType();
|
||||||
MS_EXCEPTION_IF_NULL(x_type);
|
MS_EXCEPTION_IF_NULL(x_type);
|
||||||
|
MS_EXCEPTION_IF_NULL(begin_v);
|
||||||
|
MS_EXCEPTION_IF_NULL(size_v);
|
||||||
auto tensor_type = x_type->cast<TensorTypePtr>();
|
auto tensor_type = x_type->cast<TensorTypePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||||
auto data_type = tensor_type->element();
|
auto data_type = tensor_type->element();
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -29,6 +29,7 @@ void SubFusion::set_activation_type(const ActivationType &activation_type) {
|
||||||
|
|
||||||
ActivationType SubFusion::get_activation_type() const {
|
ActivationType SubFusion::get_activation_type() const {
|
||||||
auto value_ptr = GetAttr(kActivationType);
|
auto value_ptr = GetAttr(kActivationType);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_C(kNameSubFusion, SubFusion);
|
REGISTER_PRIMITIVE_C(kNameSubFusion, SubFusion);
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -26,6 +26,7 @@ void TileFusion::set_dims(const std::vector<int64_t> &dims) { (void)this->AddAtt
|
||||||
|
|
||||||
std::vector<int64_t> TileFusion::get_dims() const {
|
std::vector<int64_t> TileFusion::get_dims() const {
|
||||||
auto value_ptr = GetAttr(kDims);
|
auto value_ptr = GetAttr(kDims);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_C(kNameTileFusion, TileFusion);
|
REGISTER_PRIMITIVE_C(kNameTileFusion, TileFusion);
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -32,11 +32,13 @@ void TopKFusion::set_largest(const int64_t largest) { (void)this->AddAttr(kLarge
|
||||||
|
|
||||||
int64_t TopKFusion::get_axis() const {
|
int64_t TopKFusion::get_axis() const {
|
||||||
auto value_ptr = GetAttr(kAxis);
|
auto value_ptr = GetAttr(kAxis);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<int64_t>(value_ptr);
|
return GetValue<int64_t>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t TopKFusion::get_largest() const {
|
int64_t TopKFusion::get_largest() const {
|
||||||
auto value_ptr = GetAttr(kLargest);
|
auto value_ptr = GetAttr(kLargest);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<int64_t>(value_ptr);
|
return GetValue<int64_t>(value_ptr);
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_C(kNameTopKFusion, TopKFusion);
|
REGISTER_PRIMITIVE_C(kNameTopKFusion, TopKFusion);
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -38,6 +38,7 @@ void ActivationGrad::set_activation_type(const ActivationType &type) {
|
||||||
|
|
||||||
ActivationType ActivationGrad::get_activation_type() const {
|
ActivationType ActivationGrad::get_activation_type() const {
|
||||||
auto value_ptr = GetAttr(kActivationType);
|
auto value_ptr = GetAttr(kActivationType);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -45,6 +46,7 @@ void ActivationGrad::set_alpha(const float alpha) { (void)this->AddAttr(kAlpha,
|
||||||
|
|
||||||
float ActivationGrad::get_alpha() const {
|
float ActivationGrad::get_alpha() const {
|
||||||
auto value_ptr = GetAttr(kAlpha);
|
auto value_ptr = GetAttr(kAlpha);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<float>(value_ptr);
|
return GetValue<float>(value_ptr);
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_C(kNameActivationGrad, ActivationGrad);
|
REGISTER_PRIMITIVE_C(kNameActivationGrad, ActivationGrad);
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -31,6 +31,7 @@ void BatchNormGrad::set_epsilon(const float epsilon) { (void)this->AddAttr(kEpsi
|
||||||
|
|
||||||
float BatchNormGrad::get_epsilon() const {
|
float BatchNormGrad::get_epsilon() const {
|
||||||
auto value_ptr = this->GetAttr(kEpsilon);
|
auto value_ptr = this->GetAttr(kEpsilon);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<float>(value_ptr);
|
return GetValue<float>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -38,15 +39,19 @@ void BatchNormGrad::set_is_training(const bool is_training) { this->AddAttr(kIsT
|
||||||
|
|
||||||
bool BatchNormGrad::get_is_training() const {
|
bool BatchNormGrad::get_is_training() const {
|
||||||
auto value_ptr = this->GetAttr(kIsTraining);
|
auto value_ptr = this->GetAttr(kIsTraining);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<bool>(value_ptr);
|
return GetValue<bool>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr BatchNormGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr BatchNormGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
MS_EXCEPTION_IF_NULL(input_args[1]);
|
for (auto item : input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(input_args[2]);
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
MS_EXCEPTION_IF_NULL(input_args[3]);
|
}
|
||||||
|
const int64_t input_num = 5;
|
||||||
|
CheckAndConvertUtils::CheckInteger("BatchNormGrad infer", SizeToLong(input_args.size()), kGreaterEqual, input_num,
|
||||||
|
primitive->name());
|
||||||
auto y_backprop_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
auto y_backprop_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||||
CheckAndConvertUtils::Check("BatchNorm y_backprop_shape", y_backprop_shape, kEqual, "BatchNorm x_shape", x_shape);
|
CheckAndConvertUtils::Check("BatchNorm y_backprop_shape", y_backprop_shape, kEqual, "BatchNorm x_shape", x_shape);
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -38,9 +38,6 @@ abstract::ShapePtr BinaryCrossEntroyGradInferShape(const PrimitivePtr &primitive
|
||||||
}
|
}
|
||||||
|
|
||||||
TypePtr BinaryCrossEntroyGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
TypePtr BinaryCrossEntroyGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
for (const auto &item : input_args) {
|
|
||||||
MS_EXCEPTION_IF_NULL(item);
|
|
||||||
}
|
|
||||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||||
std::map<std::string, TypePtr> types;
|
std::map<std::string, TypePtr> types;
|
||||||
types.emplace("x_shape", input_args[0]->BuildType());
|
types.emplace("x_shape", input_args[0]->BuildType());
|
||||||
|
@ -67,6 +64,13 @@ Reduction BinaryCrossEntropyGrad::get_reduction() const {
|
||||||
|
|
||||||
AbstractBasePtr BinaryCrossEntropyGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr BinaryCrossEntropyGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
for (const auto &item : input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
|
}
|
||||||
|
const int64_t input_num = 4;
|
||||||
|
CheckAndConvertUtils::CheckInteger("BinaryCrossEntropyGrad infer", SizeToLong(input_args.size()), kGreaterEqual,
|
||||||
|
input_num, primitive->name());
|
||||||
return std::make_shared<abstract::AbstractTensor>(BinaryCrossEntroyGradInferType(primitive, input_args),
|
return std::make_shared<abstract::AbstractTensor>(BinaryCrossEntroyGradInferType(primitive, input_args),
|
||||||
BinaryCrossEntroyGradInferShape(primitive, input_args)->shape());
|
BinaryCrossEntroyGradInferShape(primitive, input_args)->shape());
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -30,6 +30,7 @@ void BNGrad::set_eps(const float eps) { (void)this->AddAttr(kEps, MakeValue(eps)
|
||||||
|
|
||||||
float BNGrad::get_eps() const {
|
float BNGrad::get_eps() const {
|
||||||
auto value_ptr = this->GetAttr(kEps);
|
auto value_ptr = this->GetAttr(kEps);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<float>(value_ptr);
|
return GetValue<float>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -37,6 +38,7 @@ void BNGrad::set_momentum(const float momentum) { (void)this->AddAttr(kMomentum,
|
||||||
|
|
||||||
float BNGrad::get_momentum() const {
|
float BNGrad::get_momentum() const {
|
||||||
auto value_ptr = this->GetAttr(kMomentum);
|
auto value_ptr = this->GetAttr(kMomentum);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<float>(value_ptr);
|
return GetValue<float>(value_ptr);
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_C(kNameBNGrad, BNGrad);
|
REGISTER_PRIMITIVE_C(kNameBNGrad, BNGrad);
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -71,6 +71,7 @@ void Conv2DBackpropFilter::set_out_channel(const int64_t out_channel) {
|
||||||
|
|
||||||
int64_t Conv2DBackpropFilter::get_out_channel() const {
|
int64_t Conv2DBackpropFilter::get_out_channel() const {
|
||||||
auto value_ptr = GetAttr(kOutChannel);
|
auto value_ptr = GetAttr(kOutChannel);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<int64_t>(value_ptr);
|
return GetValue<int64_t>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -80,6 +81,7 @@ void Conv2DBackpropFilter::set_kernel_size(const std::vector<int64_t> &kernel_si
|
||||||
|
|
||||||
std::vector<int64_t> Conv2DBackpropFilter::get_kernel_size() const {
|
std::vector<int64_t> Conv2DBackpropFilter::get_kernel_size() const {
|
||||||
auto value_ptr = GetAttr(kKernelSize);
|
auto value_ptr = GetAttr(kKernelSize);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -90,6 +92,7 @@ void Conv2DBackpropFilter::set_pad_mode(const PadMode &pad_mode) {
|
||||||
|
|
||||||
PadMode Conv2DBackpropFilter::get_pad_mode() const {
|
PadMode Conv2DBackpropFilter::get_pad_mode() const {
|
||||||
auto value_ptr = GetAttr(kPadMode);
|
auto value_ptr = GetAttr(kPadMode);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return PadMode(GetValue<int64_t>(value_ptr));
|
return PadMode(GetValue<int64_t>(value_ptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -99,6 +102,7 @@ void Conv2DBackpropFilter::set_pad_list(const std::vector<int64_t> &pad_list) {
|
||||||
|
|
||||||
std::vector<int64_t> Conv2DBackpropFilter::get_pad_list() const {
|
std::vector<int64_t> Conv2DBackpropFilter::get_pad_list() const {
|
||||||
auto value_ptr = GetAttr(kPadList);
|
auto value_ptr = GetAttr(kPadList);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -106,6 +110,7 @@ void Conv2DBackpropFilter::set_mode(const int64_t mode) { (void)this->AddAttr(kM
|
||||||
|
|
||||||
int64_t Conv2DBackpropFilter::get_mode() const {
|
int64_t Conv2DBackpropFilter::get_mode() const {
|
||||||
auto value_ptr = GetAttr(kMode);
|
auto value_ptr = GetAttr(kMode);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<int64_t>(value_ptr);
|
return GetValue<int64_t>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -113,6 +118,7 @@ void Conv2DBackpropFilter::set_stride(const std::vector<int64_t> &stride) { this
|
||||||
|
|
||||||
std::vector<int64_t> Conv2DBackpropFilter::get_stride() const {
|
std::vector<int64_t> Conv2DBackpropFilter::get_stride() const {
|
||||||
auto value_ptr = GetAttr(kStride);
|
auto value_ptr = GetAttr(kStride);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -122,6 +128,7 @@ void Conv2DBackpropFilter::set_dilation(const std::vector<int64_t> &dilation) {
|
||||||
|
|
||||||
std::vector<int64_t> Conv2DBackpropFilter::get_dilation() const {
|
std::vector<int64_t> Conv2DBackpropFilter::get_dilation() const {
|
||||||
auto value_ptr = GetAttr(kDilation);
|
auto value_ptr = GetAttr(kDilation);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -129,6 +136,7 @@ void Conv2DBackpropFilter::set_group(const int64_t group) { (void)this->AddAttr(
|
||||||
|
|
||||||
int64_t Conv2DBackpropFilter::get_group() const {
|
int64_t Conv2DBackpropFilter::get_group() const {
|
||||||
auto value_ptr = GetAttr(kGroup);
|
auto value_ptr = GetAttr(kGroup);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<int64_t>(value_ptr);
|
return GetValue<int64_t>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -139,6 +147,7 @@ void Conv2DBackpropFilter::set_format(const Format &format) {
|
||||||
|
|
||||||
Format Conv2DBackpropFilter::get_format() const {
|
Format Conv2DBackpropFilter::get_format() const {
|
||||||
auto value_ptr = GetAttr(kFormat);
|
auto value_ptr = GetAttr(kFormat);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return Format(GetValue<int64_t>(value_ptr));
|
return Format(GetValue<int64_t>(value_ptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -172,51 +172,61 @@ void Conv2DBackpropInput::set_pad_list(const std::vector<int64_t> &pad_list) {
|
||||||
|
|
||||||
int64_t Conv2DBackpropInput::get_out_channel() const {
|
int64_t Conv2DBackpropInput::get_out_channel() const {
|
||||||
auto value_ptr = GetAttr(kOutChannel);
|
auto value_ptr = GetAttr(kOutChannel);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<int64_t>(value_ptr);
|
return GetValue<int64_t>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int64_t> Conv2DBackpropInput::get_kernel_size() const {
|
std::vector<int64_t> Conv2DBackpropInput::get_kernel_size() const {
|
||||||
auto value_ptr = GetAttr(kKernelSize);
|
auto value_ptr = GetAttr(kKernelSize);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int64_t> Conv2DBackpropInput::get_stride() const {
|
std::vector<int64_t> Conv2DBackpropInput::get_stride() const {
|
||||||
auto value_ptr = GetAttr(kStride);
|
auto value_ptr = GetAttr(kStride);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int64_t> Conv2DBackpropInput::get_dilation() const {
|
std::vector<int64_t> Conv2DBackpropInput::get_dilation() const {
|
||||||
auto value_ptr = GetAttr(kDilation);
|
auto value_ptr = GetAttr(kDilation);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
PadMode Conv2DBackpropInput::get_pad_mode() const {
|
PadMode Conv2DBackpropInput::get_pad_mode() const {
|
||||||
auto value_ptr = GetAttr(kPadMode);
|
auto value_ptr = GetAttr(kPadMode);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return PadMode(GetValue<int64_t>(value_ptr));
|
return PadMode(GetValue<int64_t>(value_ptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int64_t> Conv2DBackpropInput::get_pad() const {
|
std::vector<int64_t> Conv2DBackpropInput::get_pad() const {
|
||||||
auto value_ptr = GetAttr(kPad);
|
auto value_ptr = GetAttr(kPad);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t Conv2DBackpropInput::get_mode() const {
|
int64_t Conv2DBackpropInput::get_mode() const {
|
||||||
auto value_ptr = GetAttr(kMode);
|
auto value_ptr = GetAttr(kMode);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<int64_t>(value_ptr);
|
return GetValue<int64_t>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t Conv2DBackpropInput::get_group() const {
|
int64_t Conv2DBackpropInput::get_group() const {
|
||||||
auto value_ptr = GetAttr(kGroup);
|
auto value_ptr = GetAttr(kGroup);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<int64_t>(value_ptr);
|
return GetValue<int64_t>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
Format Conv2DBackpropInput::get_format() const {
|
Format Conv2DBackpropInput::get_format() const {
|
||||||
auto value_ptr = GetAttr(kFormat);
|
auto value_ptr = GetAttr(kFormat);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return Format(GetValue<int64_t>(value_ptr));
|
return Format(GetValue<int64_t>(value_ptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int64_t> Conv2DBackpropInput::get_pad_list() const {
|
std::vector<int64_t> Conv2DBackpropInput::get_pad_list() const {
|
||||||
auto value_ptr = GetAttr(kPadList);
|
auto value_ptr = GetAttr(kPadList);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DBackpropInput, prim::kPrimConv2DBackpropInput, Conv2DBackpropInputInfer, nullptr,
|
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DBackpropInput, prim::kPrimConv2DBackpropInput, Conv2DBackpropInputInfer, nullptr,
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -45,6 +45,7 @@ void DeConv2DGradFilter::set_in_channel(const int64_t in_channel) {
|
||||||
|
|
||||||
int64_t DeConv2DGradFilter::get_in_channel() const {
|
int64_t DeConv2DGradFilter::get_in_channel() const {
|
||||||
auto value_ptr = GetAttr(kInChannel);
|
auto value_ptr = GetAttr(kInChannel);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<int64_t>(value_ptr);
|
return GetValue<int64_t>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -54,6 +55,7 @@ void DeConv2DGradFilter::set_out_channel(const int64_t out_channel) {
|
||||||
|
|
||||||
int64_t DeConv2DGradFilter::get_out_channel() const {
|
int64_t DeConv2DGradFilter::get_out_channel() const {
|
||||||
auto value_ptr = GetAttr(kOutChannel);
|
auto value_ptr = GetAttr(kOutChannel);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<int64_t>(value_ptr);
|
return GetValue<int64_t>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -63,6 +65,7 @@ void DeConv2DGradFilter::set_kernel_size(const std::vector<int64_t> &kernel_size
|
||||||
|
|
||||||
std::vector<int64_t> DeConv2DGradFilter::get_kernel_size() const {
|
std::vector<int64_t> DeConv2DGradFilter::get_kernel_size() const {
|
||||||
auto value_ptr = GetAttr(kKernelSize);
|
auto value_ptr = GetAttr(kKernelSize);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -73,6 +76,7 @@ void DeConv2DGradFilter::set_pad_mode(const PadMode &pad_mode) {
|
||||||
|
|
||||||
PadMode DeConv2DGradFilter::get_pad_mode() const {
|
PadMode DeConv2DGradFilter::get_pad_mode() const {
|
||||||
auto value_ptr = GetAttr(kPadMode);
|
auto value_ptr = GetAttr(kPadMode);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return PadMode(GetValue<int64_t>(value_ptr));
|
return PadMode(GetValue<int64_t>(value_ptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -82,6 +86,7 @@ void DeConv2DGradFilter::set_pad_list(const std::vector<int64_t> &pad_list) {
|
||||||
|
|
||||||
std::vector<int64_t> DeConv2DGradFilter::get_pad_list() const {
|
std::vector<int64_t> DeConv2DGradFilter::get_pad_list() const {
|
||||||
auto value_ptr = GetAttr(kPadList);
|
auto value_ptr = GetAttr(kPadList);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -91,6 +96,7 @@ void DeConv2DGradFilter::set_stride(const std::vector<int64_t> &stride) {
|
||||||
|
|
||||||
std::vector<int64_t> DeConv2DGradFilter::get_stride() const {
|
std::vector<int64_t> DeConv2DGradFilter::get_stride() const {
|
||||||
auto value_ptr = GetAttr(kStride);
|
auto value_ptr = GetAttr(kStride);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -100,6 +106,7 @@ void DeConv2DGradFilter::set_dilation(const std::vector<int64_t> &dilation) {
|
||||||
|
|
||||||
std::vector<int64_t> DeConv2DGradFilter::get_dilation() const {
|
std::vector<int64_t> DeConv2DGradFilter::get_dilation() const {
|
||||||
auto value_ptr = GetAttr(kDilation);
|
auto value_ptr = GetAttr(kDilation);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -107,6 +114,7 @@ void DeConv2DGradFilter::set_group(const int64_t group) { (void)this->AddAttr(kG
|
||||||
|
|
||||||
int64_t DeConv2DGradFilter::get_group() const {
|
int64_t DeConv2DGradFilter::get_group() const {
|
||||||
auto value_ptr = GetAttr(kGroup);
|
auto value_ptr = GetAttr(kGroup);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<int64_t>(value_ptr);
|
return GetValue<int64_t>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -117,6 +125,7 @@ void DeConv2DGradFilter::set_format(const Format &format) {
|
||||||
|
|
||||||
Format DeConv2DGradFilter::get_format() const {
|
Format DeConv2DGradFilter::get_format() const {
|
||||||
auto value_ptr = GetAttr(kFormat);
|
auto value_ptr = GetAttr(kFormat);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return Format(GetValue<int64_t>(value_ptr));
|
return Format(GetValue<int64_t>(value_ptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -127,6 +136,7 @@ void DeConv2DGradFilter::set_activation_type(const ActivationType &activation_ty
|
||||||
|
|
||||||
ActivationType DeConv2DGradFilter::get_activation_type() const {
|
ActivationType DeConv2DGradFilter::get_activation_type() const {
|
||||||
auto value_ptr = GetAttr(kActivationType);
|
auto value_ptr = GetAttr(kActivationType);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -134,6 +144,7 @@ void DeConv2DGradFilter::set_has_bias(const bool has_bias) { (void)this->AddAttr
|
||||||
|
|
||||||
bool DeConv2DGradFilter::get_has_bias() const {
|
bool DeConv2DGradFilter::get_has_bias() const {
|
||||||
auto value_ptr = GetAttr(kHasBias);
|
auto value_ptr = GetAttr(kHasBias);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<bool>(value_ptr);
|
return GetValue<bool>(value_ptr);
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_C(kNameDeConv2DGradFilter, DeConv2DGradFilter);
|
REGISTER_PRIMITIVE_C(kNameDeConv2DGradFilter, DeConv2DGradFilter);
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -28,6 +28,7 @@ void DropoutGrad::set_keep_prob(const float keep_prob) {
|
||||||
|
|
||||||
float DropoutGrad::get_keep_prob() const {
|
float DropoutGrad::get_keep_prob() const {
|
||||||
auto value_ptr = GetAttr(kKeepProb);
|
auto value_ptr = GetAttr(kKeepProb);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<float>(value_ptr);
|
return GetValue<float>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,6 +56,13 @@ TypePtr DropoutGradInferType(const PrimitivePtr &prim, const std::vector<Abstrac
|
||||||
|
|
||||||
AbstractBasePtr DropoutGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr DropoutGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
for (auto item : input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
|
}
|
||||||
|
const int64_t input_num = 2;
|
||||||
|
CheckAndConvertUtils::CheckInteger("DropoutGrad infer", SizeToLong(input_args.size()), kGreaterEqual, input_num,
|
||||||
|
primitive->name());
|
||||||
return std::make_shared<abstract::AbstractTensor>(DropoutGradInferType(primitive, input_args),
|
return std::make_shared<abstract::AbstractTensor>(DropoutGradInferType(primitive, input_args),
|
||||||
DropoutGradInferShape(primitive, input_args)->shape());
|
DropoutGradInferShape(primitive, input_args)->shape());
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -22,7 +22,8 @@ AbstractBasePtr FlattenGradInfer(const abstract::AnalysisEnginePtr &, const Prim
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
auto prim_name = primitive->name();
|
auto prim_name = primitive->name();
|
||||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, prim_name);
|
const int64_t input_num = 2;
|
||||||
|
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
||||||
for (const auto &item : input_args) {
|
for (const auto &item : input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(item);
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -46,6 +46,7 @@ void GroupConv2DGradInput::set_in_channel(const int64_t &in_channel) {
|
||||||
|
|
||||||
int64_t GroupConv2DGradInput::get_in_channel() const {
|
int64_t GroupConv2DGradInput::get_in_channel() const {
|
||||||
auto value_ptr = GetAttr(kInChannel);
|
auto value_ptr = GetAttr(kInChannel);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<int64_t>(value_ptr);
|
return GetValue<int64_t>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,6 +56,7 @@ void GroupConv2DGradInput::set_out_channel(const int64_t &out_channel) {
|
||||||
|
|
||||||
int64_t GroupConv2DGradInput::get_out_channel() const {
|
int64_t GroupConv2DGradInput::get_out_channel() const {
|
||||||
auto value_ptr = GetAttr(kOutChannel);
|
auto value_ptr = GetAttr(kOutChannel);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<int64_t>(value_ptr);
|
return GetValue<int64_t>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -64,6 +66,7 @@ void GroupConv2DGradInput::set_kernel_size(const std::vector<int64_t> &kernel_si
|
||||||
|
|
||||||
std::vector<int64_t> GroupConv2DGradInput::get_kernel_size() const {
|
std::vector<int64_t> GroupConv2DGradInput::get_kernel_size() const {
|
||||||
auto value_ptr = GetAttr(kKernelSize);
|
auto value_ptr = GetAttr(kKernelSize);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -74,6 +77,7 @@ void GroupConv2DGradInput::set_pad_mode(const PadMode &pad_mode) {
|
||||||
|
|
||||||
PadMode GroupConv2DGradInput::get_pad_mode() const {
|
PadMode GroupConv2DGradInput::get_pad_mode() const {
|
||||||
auto value_ptr = GetAttr(kPadMode);
|
auto value_ptr = GetAttr(kPadMode);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return PadMode(GetValue<int64_t>(value_ptr));
|
return PadMode(GetValue<int64_t>(value_ptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -83,6 +87,7 @@ void GroupConv2DGradInput::set_pad_list(const std::vector<int64_t> &pad_list) {
|
||||||
|
|
||||||
std::vector<int64_t> GroupConv2DGradInput::get_pad_list() const {
|
std::vector<int64_t> GroupConv2DGradInput::get_pad_list() const {
|
||||||
auto value_ptr = GetAttr(kPadList);
|
auto value_ptr = GetAttr(kPadList);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -90,6 +95,7 @@ void GroupConv2DGradInput::set_stride(const std::vector<int64_t> &stride) { this
|
||||||
|
|
||||||
std::vector<int64_t> GroupConv2DGradInput::get_stride() const {
|
std::vector<int64_t> GroupConv2DGradInput::get_stride() const {
|
||||||
auto value_ptr = GetAttr(kStride);
|
auto value_ptr = GetAttr(kStride);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -99,6 +105,7 @@ void GroupConv2DGradInput::set_dilation(const std::vector<int64_t> &dilation) {
|
||||||
|
|
||||||
std::vector<int64_t> GroupConv2DGradInput::get_dilation() const {
|
std::vector<int64_t> GroupConv2DGradInput::get_dilation() const {
|
||||||
auto value_ptr = GetAttr(kDilation);
|
auto value_ptr = GetAttr(kDilation);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -106,6 +113,7 @@ void GroupConv2DGradInput::set_group(const int64_t &group) { (void)this->AddAttr
|
||||||
|
|
||||||
int64_t GroupConv2DGradInput::get_group() const {
|
int64_t GroupConv2DGradInput::get_group() const {
|
||||||
auto value_ptr = GetAttr(kGroup);
|
auto value_ptr = GetAttr(kGroup);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<int64_t>(value_ptr);
|
return GetValue<int64_t>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -114,7 +122,9 @@ void GroupConv2DGradInput::set_input_shape(const std::vector<int64_t> &input_sha
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int64_t> GroupConv2DGradInput::get_input_shape() const {
|
std::vector<int64_t> GroupConv2DGradInput::get_input_shape() const {
|
||||||
return GetValue<std::vector<int64_t>>(GetAttr(kInputShape));
|
auto value_ptr = GetAttr(kInputShape);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
|
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
void GroupConv2DGradInput::set_format(const Format &format) {
|
void GroupConv2DGradInput::set_format(const Format &format) {
|
||||||
|
@ -124,6 +134,7 @@ void GroupConv2DGradInput::set_format(const Format &format) {
|
||||||
|
|
||||||
Format GroupConv2DGradInput::get_format() const {
|
Format GroupConv2DGradInput::get_format() const {
|
||||||
auto value_ptr = GetAttr(kFormat);
|
auto value_ptr = GetAttr(kFormat);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return Format(GetValue<int64_t>(value_ptr));
|
return Format(GetValue<int64_t>(value_ptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -134,6 +145,7 @@ void GroupConv2DGradInput::set_activation_type(const ActivationType &activation_
|
||||||
|
|
||||||
ActivationType GroupConv2DGradInput::get_activation_type() const {
|
ActivationType GroupConv2DGradInput::get_activation_type() const {
|
||||||
auto value_ptr = GetAttr(kActivationType);
|
auto value_ptr = GetAttr(kActivationType);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -141,6 +153,7 @@ void GroupConv2DGradInput::set_has_bias(const bool has_bias) { (void)this->AddAt
|
||||||
|
|
||||||
bool GroupConv2DGradInput::get_has_bias() const {
|
bool GroupConv2DGradInput::get_has_bias() const {
|
||||||
auto value_ptr = GetAttr(kHasBias);
|
auto value_ptr = GetAttr(kHasBias);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<bool>(value_ptr);
|
return GetValue<bool>(value_ptr);
|
||||||
}
|
}
|
||||||
AbstractBasePtr GroupConv2DGradInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr GroupConv2DGradInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
@ -152,11 +165,16 @@ AbstractBasePtr GroupConv2DGradInputInfer(const abstract::AnalysisEnginePtr &, c
|
||||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||||
|
|
||||||
// Infer shape
|
// Infer shape
|
||||||
auto shape = GetValue<std::vector<int64_t>>(primitive->GetAttr(kInputShape));
|
auto shape_ptr = primitive->GetAttr(kInputShape);
|
||||||
|
MS_EXCEPTION_IF_NULL(shape_ptr);
|
||||||
|
auto shape = GetValue<std::vector<int64_t>>(shape_ptr);
|
||||||
|
|
||||||
// Infer type
|
// Infer type
|
||||||
auto type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
|
auto type_ptr = input_args[0]->BuildType();
|
||||||
|
MS_EXCEPTION_IF_NULL(type_ptr);
|
||||||
|
auto type_tensor_ptr = type_ptr->cast<TensorTypePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(type_tensor_ptr);
|
||||||
|
auto type = type_tensor_ptr->element();
|
||||||
return std::make_shared<abstract::AbstractTensor>(type, shape);
|
return std::make_shared<abstract::AbstractTensor>(type, shape);
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_C(kNameGroupConv2DGradInput, GroupConv2DGradInput);
|
REGISTER_PRIMITIVE_C(kNameGroupConv2DGradInput, GroupConv2DGradInput);
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -48,10 +48,12 @@ void LayerNormGrad::set_begin_params_axis(const int64_t begin_params_axis) {
|
||||||
}
|
}
|
||||||
int64_t LayerNormGrad::get_begin_norm_axis() const {
|
int64_t LayerNormGrad::get_begin_norm_axis() const {
|
||||||
auto value_ptr = this->GetAttr(kBeginNormAxis);
|
auto value_ptr = this->GetAttr(kBeginNormAxis);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<int64_t>(value_ptr);
|
return GetValue<int64_t>(value_ptr);
|
||||||
}
|
}
|
||||||
int64_t LayerNormGrad::get_begin_params_axis() const {
|
int64_t LayerNormGrad::get_begin_params_axis() const {
|
||||||
auto value_ptr = this->GetAttr(kBeginParamsAxis);
|
auto value_ptr = this->GetAttr(kBeginParamsAxis);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
return GetValue<int64_t>(value_ptr);
|
return GetValue<int64_t>(value_ptr);
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_EVAL_IMPL(LayerNormGrad, prim::kPrimLayerNormGrad, LayerNormGradInfer, nullptr, true);
|
REGISTER_PRIMITIVE_EVAL_IMPL(LayerNormGrad, prim::kPrimLayerNormGrad, LayerNormGradInfer, nullptr, true);
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue