forked from mindspore-Ecosystem/mindspore
commit
ad589e6780
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -39,6 +39,7 @@ class RecvGpuKernel : public GpuKernel {
|
|||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_node_ = kernel_node;
|
||||
wait_stream_ = reinterpret_cast<cudaStream_t>(GetAttr<uintptr_t>(kernel_node, "wait_event_stream"));
|
||||
wait_event_ = reinterpret_cast<cudaEvent_t>(GetAttr<uintptr_t>(kernel_node, "wait_event"));
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -39,6 +39,7 @@ class SendGpuKernel : public GpuKernel {
|
|||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_node_ = kernel_node;
|
||||
record_stream_ = reinterpret_cast<cudaStream_t>(GetAttr<uintptr_t>(kernel_node, "record_event_stream"));
|
||||
record_event_ = reinterpret_cast<cudaEvent_t>(GetAttr<uintptr_t>(kernel_node, "record_event"));
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -34,11 +34,17 @@ const std::vector<size_t> &DatasetInitKernel::GetOutputSizeList() const { return
|
|||
const std::vector<size_t> &DatasetInitKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
|
||||
|
||||
bool DatasetInitKernel::Init(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
queue_name_ = GetAttr<std::string>(kernel_node, "queue_name");
|
||||
std::vector<std::vector<int>> shapes;
|
||||
std::vector<TypePtr> types;
|
||||
GetShapeAndType(kernel_node, &shapes, &types);
|
||||
|
||||
for (auto item : types) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
if (types.size() < shapes.size()) {
|
||||
MS_LOG(EXCEPTION) << "types size is less than shapes size.";
|
||||
}
|
||||
for (size_t i = 0; i < shapes.size(); i++) {
|
||||
int unit = UnitSizeInBytes(types[i]->type_id());
|
||||
int nums = ElementNums(shapes[i]);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -47,12 +47,18 @@ const std::vector<size_t> &DatasetIteratorKernel::GetOutputSizeList() const { re
|
|||
const std::vector<size_t> &DatasetIteratorKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
|
||||
|
||||
bool DatasetIteratorKernel::Init(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_node_ = kernel_node;
|
||||
queue_name_ = GetAttr<std::string>(kernel_node, "shared_name");
|
||||
std::vector<std::vector<int>> shapes;
|
||||
std::vector<TypePtr> types;
|
||||
GetShapeAndType(kernel_node, &shapes, &types);
|
||||
|
||||
for (auto item : types) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
if (types.size() < shapes.size()) {
|
||||
MS_LOG(EXCEPTION) << "types size is less than shapes size.";
|
||||
}
|
||||
for (size_t i = 0; i < shapes.size(); i++) {
|
||||
int unit = UnitSizeInBytes(types[i]->type_id());
|
||||
int nums = ElementNums(shapes[i]);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -48,6 +48,9 @@ void GetNextProfiling::SaveProfilingData() {
|
|||
return;
|
||||
}
|
||||
for (uint32_t index = 0; index < queue_size_.size(); index++) {
|
||||
if (index > time_stamp_.size() - 1) {
|
||||
MS_LOG(EXCEPTION) << "index exceeds time_stamp_ size.";
|
||||
}
|
||||
handle << Name() << " " << time_stamp_[index].first << " " << time_stamp_[index].second << " " << queue_size_[index]
|
||||
<< std::endl;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -68,6 +68,7 @@ int ElementNums(const std::vector<int> &shape) {
|
|||
}
|
||||
|
||||
void GetShapeAndType(const CNodePtr &kernel_node, std::vector<std::vector<int>> *shapes, std::vector<TypePtr> *types) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(shapes);
|
||||
MS_EXCEPTION_IF_NULL(types);
|
||||
std::vector<std::vector<int64_t>> shapes_me =
|
||||
|
|
|
@ -87,6 +87,7 @@ class PrintGpuKernel : public GpuKernel {
|
|||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_node_ = kernel_node;
|
||||
if (AnfAlgo::HasNodeAttr("string_pos", kernel_node)) {
|
||||
string_value_ = GetAttr<std::vector<std::string>>(kernel_node, "string_value");
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -77,6 +77,7 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel {
|
|||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_node_ = kernel_node;
|
||||
nccl_data_type_ = nccl_dtype(AnfAlgo::GetInputDeviceDataType(kernel_node, 0));
|
||||
InferCommType(kernel_node);
|
||||
|
|
|
@ -49,6 +49,7 @@ class AssignGpuKernel : public GpuKernel {
|
|||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_node_ = kernel_node;
|
||||
if (!CheckParam(kernel_node)) {
|
||||
return false;
|
||||
|
@ -71,6 +72,7 @@ class AssignGpuKernel : public GpuKernel {
|
|||
|
||||
private:
|
||||
bool CheckParam(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 2) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but AssignGpuKernel needs 2 output.";
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -62,6 +62,7 @@ class BoundingBoxDecodeGpuKernel : public GpuKernel {
|
|||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 2) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but BoundingBoxDecode needs 2 inputs.";
|
||||
|
@ -89,10 +90,11 @@ class BoundingBoxDecodeGpuKernel : public GpuKernel {
|
|||
InitSizeLists();
|
||||
|
||||
const size_t coordinate_size = 4;
|
||||
if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<ValueTuple>() ||
|
||||
AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<ValueList>()) {
|
||||
auto means = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means");
|
||||
MS_EXCEPTION_IF_NULL(means);
|
||||
if (means->isa<ValueTuple>() || means->isa<ValueList>()) {
|
||||
means_ = GetAttr<std::vector<float>>(kernel_node, "means");
|
||||
} else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<FloatImm>()) {
|
||||
} else if (means->isa<FloatImm>()) {
|
||||
float mean = GetAttr<float>(kernel_node, "means");
|
||||
for (size_t i = 0; i < coordinate_size; i++) {
|
||||
means_.emplace_back(mean);
|
||||
|
@ -101,10 +103,11 @@ class BoundingBoxDecodeGpuKernel : public GpuKernel {
|
|||
MS_LOG(EXCEPTION) << "Attribute means type is invalid.";
|
||||
}
|
||||
|
||||
if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<ValueTuple>() ||
|
||||
AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<ValueList>()) {
|
||||
auto stds = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds");
|
||||
MS_EXCEPTION_IF_NULL(stds);
|
||||
if (stds->isa<ValueTuple>() || stds->isa<ValueList>()) {
|
||||
stds_ = GetAttr<std::vector<float>>(kernel_node, "stds");
|
||||
} else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<FloatImm>()) {
|
||||
} else if (stds->isa<FloatImm>()) {
|
||||
float std = GetAttr<float>(kernel_node, "stds");
|
||||
for (size_t i = 0; i < coordinate_size; i++) {
|
||||
stds_.emplace_back(std);
|
||||
|
|
|
@ -61,6 +61,7 @@ class BoundingBoxEncodeGpuKernel : public GpuKernel {
|
|||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 2) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but BoundingBoxEncode needs 2 inputs.";
|
||||
|
@ -88,10 +89,11 @@ class BoundingBoxEncodeGpuKernel : public GpuKernel {
|
|||
InitSizeLists();
|
||||
|
||||
const size_t coordinate_size = 4;
|
||||
if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<ValueTuple>() ||
|
||||
AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<ValueList>()) {
|
||||
auto means = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means");
|
||||
MS_EXCEPTION_IF_NULL(means);
|
||||
if (means->isa<ValueTuple>() || means->isa<ValueList>()) {
|
||||
means_ = GetAttr<std::vector<float>>(kernel_node, "means");
|
||||
} else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<FloatImm>()) {
|
||||
} else if (means->isa<FloatImm>()) {
|
||||
float mean = GetAttr<float>(kernel_node, "means");
|
||||
for (size_t i = 0; i < coordinate_size; i++) {
|
||||
means_.emplace_back(mean);
|
||||
|
@ -99,11 +101,11 @@ class BoundingBoxEncodeGpuKernel : public GpuKernel {
|
|||
} else {
|
||||
MS_LOG(EXCEPTION) << "Attribute means type is invalid.";
|
||||
}
|
||||
|
||||
if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<ValueTuple>() ||
|
||||
AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<ValueList>()) {
|
||||
auto stds = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds");
|
||||
MS_EXCEPTION_IF_NULL(stds);
|
||||
if (stds->isa<ValueTuple>() || stds->isa<ValueList>()) {
|
||||
stds_ = GetAttr<std::vector<float>>(kernel_node, "stds");
|
||||
} else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<FloatImm>()) {
|
||||
} else if (stds->isa<FloatImm>()) {
|
||||
float std = GetAttr<float>(kernel_node, "stds");
|
||||
for (size_t i = 0; i < coordinate_size; i++) {
|
||||
stds_.emplace_back(std);
|
||||
|
|
|
@ -55,6 +55,7 @@ class CheckValidGpuKernel : public GpuKernel {
|
|||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 2) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but CheckValid needs 2 inputs.";
|
||||
|
|
|
@ -59,6 +59,7 @@ class GpuConvertToDynamicShapeGpuKernel : public GpuKernel {
|
|||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_node_ = kernel_node;
|
||||
size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_count != 1) {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -59,6 +59,7 @@ class IOUGpuKernel : public GpuKernel {
|
|||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 2) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but IOU needs 2 inputs.";
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -89,6 +89,7 @@ class RandomCategoricalGpuKernel : public GpuKernel {
|
|||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_node_ = kernel_node;
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 3) {
|
||||
|
|
|
@ -70,6 +70,7 @@ class RandomChoiceWithMaskGpuKernel : public GpuKernel {
|
|||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
uint32_t time_interval = std::chrono::system_clock::now().time_since_epoch().count();
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 1) {
|
||||
|
|
|
@ -76,6 +76,7 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel {
|
|||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_node_ = kernel_node;
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 1) {
|
||||
|
@ -173,6 +174,7 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel {
|
|||
MS_LOG(EXCEPTION) << "range_max_ failed to cast";
|
||||
}
|
||||
range = static_cast<S>(range_max_);
|
||||
MS_EXCEPTION_IF_ZERO("range", range);
|
||||
return static_cast<S>(1.0f / range);
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -42,6 +42,7 @@ void AdderFusion::set_activation_type(const ActivationType activation_type) {
|
|||
|
||||
ActivationType AdderFusion::get_activation_type() const {
|
||||
auto value_ptr = GetAttr(kActivationType);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameAdderFusion, AdderFusion);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -31,9 +31,21 @@ void ArgMaxFusion::set_out_max_value(const bool out_max_value) {
|
|||
}
|
||||
void ArgMaxFusion::set_top_k(const int64_t top_k) { (void)this->AddAttr(kTopK, MakeValue(top_k)); }
|
||||
|
||||
bool ArgMaxFusion::get_keep_dims() const { return GetValue<bool>(GetAttr(kKeepDims)); }
|
||||
bool ArgMaxFusion::get_out_max_value() const { return GetValue<bool>(GetAttr(kOutMaxValue)); }
|
||||
int64_t ArgMaxFusion::get_top_k() const { return GetValue<int64_t>(GetAttr(kTopK)); }
|
||||
bool ArgMaxFusion::get_keep_dims() const {
|
||||
auto keep_dims = GetAttr(kKeepDims);
|
||||
MS_EXCEPTION_IF_NULL(keep_dims);
|
||||
return GetValue<bool>(keep_dims);
|
||||
}
|
||||
bool ArgMaxFusion::get_out_max_value() const {
|
||||
auto out_maxv = GetAttr(kOutMaxValue);
|
||||
MS_EXCEPTION_IF_NULL(out_maxv);
|
||||
return GetValue<bool>(out_maxv);
|
||||
}
|
||||
int64_t ArgMaxFusion::get_top_k() const {
|
||||
auto topk = GetAttr(kTopK);
|
||||
MS_EXCEPTION_IF_NULL(topk);
|
||||
return GetValue<int64_t>(topk);
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_C(kNameArgMaxFusion, ArgMaxFusion);
|
||||
} // namespace ops
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -29,14 +29,21 @@ void ArgMinFusion::set_keep_dims(const bool keep_dims) { (void)this->AddAttr(kKe
|
|||
void ArgMinFusion::set_out_max_value(bool out_max_value) { (void)AddAttr(kOutMaxValue, MakeValue(out_max_value)); }
|
||||
void ArgMinFusion::set_top_k(int64_t top_k) { (void)this->AddAttr(kTopK, MakeValue(top_k)); }
|
||||
|
||||
bool ArgMinFusion::get_keep_dims() const { return GetValue<bool>(GetAttr(kKeepDims)); }
|
||||
bool ArgMinFusion::get_keep_dims() const {
|
||||
auto keep_dims = GetAttr(kKeepDims);
|
||||
MS_EXCEPTION_IF_NULL(keep_dims);
|
||||
return GetValue<bool>(keep_dims);
|
||||
}
|
||||
|
||||
bool ArgMinFusion::get_out_max_value() const {
|
||||
auto value_ptr = GetAttr(kOutMaxValue);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
int64_t ArgMinFusion::get_top_k() const {
|
||||
auto value_ptr = GetAttr(kTopK);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -40,17 +40,22 @@ void AvgPoolFusion::set_activation_type(ActivationType activation_type) {
|
|||
|
||||
bool AvgPoolFusion::get_global() const {
|
||||
auto value_ptr = GetAttr(kGlobal);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
ActivationType AvgPoolFusion::get_activation_type() const {
|
||||
auto value_ptr = GetAttr(kActivationType);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
for (auto item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto op_name = primitive->name();
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
|
||||
|
@ -66,6 +71,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
auto in_w = in_shape[3];
|
||||
|
||||
auto strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides));
|
||||
(void)CheckAndConvertUtils::CheckPositiveVector(kStride, strides, op_name);
|
||||
auto kernel_h = kernel_size[2];
|
||||
auto kernel_w = kernel_size[3];
|
||||
auto stride_h = strides[2];
|
||||
|
@ -90,8 +96,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr arg) { return arg == nullptr; })) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
for (auto item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
return input_args[0]->BuildType();
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -52,11 +52,13 @@ void Conv2DBackpropFilterFusion::set_in_channel(const int64_t in_channel) {
|
|||
|
||||
ActivationType Conv2DBackpropFilterFusion::get_activation_type() const {
|
||||
auto value_ptr = GetAttr(kActivationType);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
int64_t Conv2DBackpropFilterFusion::get_in_channel() const {
|
||||
auto value_ptr = GetAttr(kInChannel);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameConv2DBackpropFilterFusion, Conv2DBackpropFilterFusion);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -51,11 +51,13 @@ void Conv2DBackpropInputFusion::set_activation_type(const ActivationType &activa
|
|||
}
|
||||
int64_t Conv2DBackpropInputFusion::get_in_channel() const {
|
||||
auto value_ptr = GetAttr(kInChannel);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
ActivationType Conv2DBackpropInputFusion::get_activation_type() const {
|
||||
auto value_ptr = GetAttr(kActivationType);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameConv2DBackpropInputFusion, Conv2DBackpropInputFusion);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -48,14 +48,17 @@ void Conv2DFusion::set_activation_type(const ActivationType &activation_type) {
|
|||
}
|
||||
int64_t Conv2DFusion::get_in_channel() const {
|
||||
auto value_ptr = GetAttr(kInChannel);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
std::vector<int64_t> Conv2DFusion::get_pad_list() const {
|
||||
auto value_ptr = GetAttr(kPadList);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
ActivationType Conv2DFusion::get_activation_type() const {
|
||||
auto value_ptr = GetAttr(kActivationType);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameConv2DFusion, Conv2DFusion);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -71,11 +71,13 @@ void Conv2dTransposeFusion::set_activation_type(ActivationType activation_type)
|
|||
|
||||
std::vector<int64_t> Conv2dTransposeFusion::get_output_paddings() const {
|
||||
auto value_ptr = GetAttr(kOutputPaddings);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
ActivationType Conv2dTransposeFusion::get_activation_type() const {
|
||||
auto value_ptr = GetAttr(kActivationType);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -29,6 +29,7 @@ void DivFusion::set_activation_type(const ActivationType &activation_type) {
|
|||
|
||||
ActivationType DivFusion::get_activation_type() const {
|
||||
auto value_ptr = GetAttr(kActivationType);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameDivFusion, DivFusion);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -23,6 +23,7 @@ namespace ops {
|
|||
void EmbeddingLookupFusion::set_max_norm(const float max_norm) { (void)this->AddAttr(kMaxNorm, MakeValue(max_norm)); }
|
||||
float EmbeddingLookupFusion::get_max_norm() const {
|
||||
auto value_ptr = GetAttr(kMaxNorm);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
void EmbeddingLookupFusion::Init(const float max_norm) { this->set_max_norm(max_norm); }
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -37,14 +37,17 @@ void ExpFusion::set_shift(const float shift) { (void)this->AddAttr(kShift, MakeV
|
|||
|
||||
float ExpFusion::get_base() const {
|
||||
auto value_ptr = GetAttr(kBase);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
float ExpFusion::get_scale() const {
|
||||
auto value_ptr = GetAttr(kScale);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
float ExpFusion::get_shift() const {
|
||||
auto value_ptr = GetAttr(kShift);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -22,13 +22,25 @@ namespace mindspore {
|
|||
namespace ops {
|
||||
void FullConnection::set_has_bias(const bool has_bias) { (void)this->AddAttr(kHasBias, MakeValue(has_bias)); }
|
||||
|
||||
bool FullConnection::get_has_bias() const { return GetValue<bool>(GetAttr(kHasBias)); }
|
||||
bool FullConnection::get_has_bias() const {
|
||||
auto value_ptr = GetAttr(kHasBias);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
void FullConnection::set_axis(const int64_t axis) { (void)this->AddAttr(kAxis, MakeValue(axis)); }
|
||||
int64_t FullConnection::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); }
|
||||
int64_t FullConnection::get_axis() const {
|
||||
auto value_ptr = GetAttr(kAxis);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
void FullConnection::set_use_axis(const bool use_axis) { (void)this->AddAttr(kUseAxis, MakeValue(use_axis)); }
|
||||
bool FullConnection::get_use_axis() const { return GetValue<bool>(GetAttr(kUseAxis)); }
|
||||
bool FullConnection::get_use_axis() const {
|
||||
auto value_ptr = GetAttr(kUseAxis);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
void FullConnection::set_activation_type(const ActivationType &activation_type) {
|
||||
int64_t swi = activation_type;
|
||||
|
@ -36,6 +48,7 @@ void FullConnection::set_activation_type(const ActivationType &activation_type)
|
|||
}
|
||||
ActivationType FullConnection::get_activation_type() const {
|
||||
auto value_ptr = GetAttr(kActivationType);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
void FullConnection::Init(const bool has_bias, const int64_t axis, const bool use_axis,
|
||||
|
@ -57,14 +70,16 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P
|
|||
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input1->BuildShape())[kShape];
|
||||
auto prim_axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
|
||||
auto has_bias = GetValue<bool>(primitive->GetAttr(kHasBias));
|
||||
const int64_t input_num_bias = 3;
|
||||
const int64_t input_num = 2;
|
||||
if (has_bias) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("input_args.size()", input_args.size(), kEqual, 3, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input_args.size()", input_args.size(), kEqual, input_num_bias, prim_name);
|
||||
} else {
|
||||
(void)CheckAndConvertUtils::CheckInteger("input_args.size()", input_args.size(), kEqual, 2, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input_args.size()", input_args.size(), kEqual, input_num, prim_name);
|
||||
}
|
||||
auto use_axis = GetValue<bool>(primitive->GetAttr(kUseAxis));
|
||||
if (use_axis && (prim_axis < 1 || prim_axis > (int64_t)input0_shape.size())) {
|
||||
MS_EXCEPTION(ValueError) << "Full Connection axis invalid";
|
||||
MS_EXCEPTION(ValueError) << "Full Connection axis is invalid";
|
||||
}
|
||||
int64_t new_k = 1;
|
||||
if (use_axis) {
|
||||
|
@ -72,7 +87,7 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P
|
|||
new_k *= input0_shape[t];
|
||||
}
|
||||
if (new_k != input1_shape[1]) {
|
||||
MS_EXCEPTION(ValueError) << "Input1 size invalid";
|
||||
MS_EXCEPTION(ValueError) << "Input1 size is invalid";
|
||||
}
|
||||
} else {
|
||||
new_k = input1_shape[1];
|
||||
|
@ -80,7 +95,7 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P
|
|||
if (has_bias) {
|
||||
auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
if (input2_shape[0] != input1_shape[0]) {
|
||||
MS_EXCEPTION(ValueError) << "Bias size invalid";
|
||||
MS_EXCEPTION(ValueError) << "Bias size is invalid";
|
||||
}
|
||||
}
|
||||
std::vector<int64_t> out_shape = {(int64_t)input0_shape.size()};
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -34,6 +34,7 @@ void L2NormalizeFusion::set_activation_type(const ActivationType &activation_typ
|
|||
|
||||
ActivationType L2NormalizeFusion::get_activation_type() const {
|
||||
auto value_ptr = GetAttr(kActivationType);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameL2NormalizeFusion, L2NormalizeFusion);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -32,6 +32,7 @@ void LayerNormFusion::set_elementwise_affine(const bool elementwise_affine) {
|
|||
|
||||
bool LayerNormFusion::get_elementwise_affine() const {
|
||||
auto value_ptr = GetAttr(kElementwiseAffine);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameLayerNormFusion, LayerNormFusion);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -40,24 +40,28 @@ void MaxPoolFusion::set_activation_type(ActivationType activation_type) {
|
|||
|
||||
bool MaxPoolFusion::get_global() const {
|
||||
auto value_ptr = GetAttr(kGlobal);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
ActivationType MaxPoolFusion::get_activation_type() const {
|
||||
auto value_ptr = GetAttr(kActivationType);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
auto op_name = primitive->name();
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
|
||||
if (format == NHWC) {
|
||||
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
|
||||
}
|
||||
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, 4, op_name);
|
||||
const int64_t in_shape_size = 4;
|
||||
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, in_shape_size, op_name);
|
||||
auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
|
||||
auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode)));
|
||||
auto batch = in_shape[0];
|
||||
|
@ -84,21 +88,22 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
out_shape = {batch, out_h, out_w, channel};
|
||||
}
|
||||
if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {
|
||||
MS_LOG(EXCEPTION) << "Kernel size is not valid.";
|
||||
MS_LOG(EXCEPTION) << "Kernel size is invalid.";
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
return input_args[0]->BuildType();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr MaxPoolFusionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
for (auto item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -29,6 +29,7 @@ void MulFusion::set_activation_type(const ActivationType &activation_type) {
|
|||
}
|
||||
ActivationType MulFusion::get_activation_type() const {
|
||||
auto value_ptr = GetAttr(kActivationType);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
void MulFusion::Init(const ActivationType &activation_type) { this->set_activation_type(activation_type); }
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -39,10 +39,12 @@ void PadFusion::set_constant_value(const float constant_value) {
|
|||
|
||||
PaddingMode PadFusion::get_padding_mode() const {
|
||||
auto value_ptr = GetAttr(kPaddingMode);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return PaddingMode(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
float PadFusion::get_constant_value() const {
|
||||
auto value_ptr = GetAttr(kConstantValue);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -25,6 +25,7 @@ void PartialFusion::set_sub_graph_index(const int64_t sub_graph_index) {
|
|||
}
|
||||
int64_t PartialFusion::get_sub_graph_index() const {
|
||||
auto value_ptr = GetAttr(kSubGraphIndex);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNamePartialFusion, PartialFusion);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -42,6 +42,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -54,6 +55,10 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
|
||||
AbstractBasePtr PowFusionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 2;
|
||||
CheckAndConvertUtils::CheckInteger("PowFusion infer", SizeToLong(input_args.size()), kGreaterEqual, input_num,
|
||||
primitive->name());
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -36,10 +36,12 @@ void PReLUFusion::set_slope(const std::vector<float> &slope) { (void)this->AddAt
|
|||
|
||||
bool PReLUFusion::get_channel_shared() const {
|
||||
auto value_ptr = GetAttr(kChannelShared);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
std::vector<float> PReLUFusion::get_slope() const {
|
||||
auto value_ptr = GetAttr(kSlope);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<float>>(value_ptr);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNamePReLUFusion, PReLUFusion);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -41,21 +41,25 @@ void ReduceFusion::set_coeff(const float coeff) { (void)this->AddAttr(kCoeff, Ma
|
|||
|
||||
bool ReduceFusion::get_keep_dims() const {
|
||||
auto value_ptr = GetAttr(kKeepDims);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
ReduceMode ReduceFusion::get_mode() const {
|
||||
auto value_ptr = GetAttr(kMode);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return ReduceMode(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
bool ReduceFusion::get_reduce_to_end() const {
|
||||
auto value_ptr = GetAttr(kReduceToEnd);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
float ReduceFusion::get_coeff() const {
|
||||
auto value_ptr = GetAttr(kCoeff);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -32,6 +32,7 @@ void ScaleFusion::set_activation_type(const ActivationType &activation_type) {
|
|||
|
||||
ActivationType ScaleFusion::get_activation_type() const {
|
||||
auto value_ptr = GetAttr(kActivationType);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameScaleFusion, ScaleFusion);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -26,6 +26,7 @@ void SliceFusion::set_axes(const std::vector<int64_t> &axes) { (void)this->AddAt
|
|||
|
||||
std::vector<int64_t> SliceFusion::get_axes() const {
|
||||
auto value_ptr = GetAttr(kAxes);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
|
@ -39,6 +40,8 @@ AbstractBasePtr SliceFusionInfer(const abstract::AnalysisEnginePtr &, const Prim
|
|||
auto size_v = input_args[2]->BuildValue();
|
||||
auto x_type = input_args[0]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(x_type);
|
||||
MS_EXCEPTION_IF_NULL(begin_v);
|
||||
MS_EXCEPTION_IF_NULL(size_v);
|
||||
auto tensor_type = x_type->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||
auto data_type = tensor_type->element();
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -29,6 +29,7 @@ void SubFusion::set_activation_type(const ActivationType &activation_type) {
|
|||
|
||||
ActivationType SubFusion::get_activation_type() const {
|
||||
auto value_ptr = GetAttr(kActivationType);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameSubFusion, SubFusion);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -26,6 +26,7 @@ void TileFusion::set_dims(const std::vector<int64_t> &dims) { (void)this->AddAtt
|
|||
|
||||
std::vector<int64_t> TileFusion::get_dims() const {
|
||||
auto value_ptr = GetAttr(kDims);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameTileFusion, TileFusion);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -32,11 +32,13 @@ void TopKFusion::set_largest(const int64_t largest) { (void)this->AddAttr(kLarge
|
|||
|
||||
int64_t TopKFusion::get_axis() const {
|
||||
auto value_ptr = GetAttr(kAxis);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
int64_t TopKFusion::get_largest() const {
|
||||
auto value_ptr = GetAttr(kLargest);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameTopKFusion, TopKFusion);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -38,6 +38,7 @@ void ActivationGrad::set_activation_type(const ActivationType &type) {
|
|||
|
||||
ActivationType ActivationGrad::get_activation_type() const {
|
||||
auto value_ptr = GetAttr(kActivationType);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
|
@ -45,6 +46,7 @@ void ActivationGrad::set_alpha(const float alpha) { (void)this->AddAttr(kAlpha,
|
|||
|
||||
float ActivationGrad::get_alpha() const {
|
||||
auto value_ptr = GetAttr(kAlpha);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameActivationGrad, ActivationGrad);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -31,6 +31,7 @@ void BatchNormGrad::set_epsilon(const float epsilon) { (void)this->AddAttr(kEpsi
|
|||
|
||||
float BatchNormGrad::get_epsilon() const {
|
||||
auto value_ptr = this->GetAttr(kEpsilon);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
|
||||
|
@ -38,15 +39,19 @@ void BatchNormGrad::set_is_training(const bool is_training) { this->AddAttr(kIsT
|
|||
|
||||
bool BatchNormGrad::get_is_training() const {
|
||||
auto value_ptr = this->GetAttr(kIsTraining);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
AbstractBasePtr BatchNormGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
MS_EXCEPTION_IF_NULL(input_args[1]);
|
||||
MS_EXCEPTION_IF_NULL(input_args[2]);
|
||||
MS_EXCEPTION_IF_NULL(input_args[3]);
|
||||
for (auto item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
const int64_t input_num = 5;
|
||||
CheckAndConvertUtils::CheckInteger("BatchNormGrad infer", SizeToLong(input_args.size()), kGreaterEqual, input_num,
|
||||
primitive->name());
|
||||
auto y_backprop_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::Check("BatchNorm y_backprop_shape", y_backprop_shape, kEqual, "BatchNorm x_shape", x_shape);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -38,9 +38,6 @@ abstract::ShapePtr BinaryCrossEntroyGradInferShape(const PrimitivePtr &primitive
|
|||
}
|
||||
|
||||
TypePtr BinaryCrossEntroyGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x_shape", input_args[0]->BuildType());
|
||||
|
@ -67,6 +64,13 @@ Reduction BinaryCrossEntropyGrad::get_reduction() const {
|
|||
|
||||
AbstractBasePtr BinaryCrossEntropyGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
const int64_t input_num = 4;
|
||||
CheckAndConvertUtils::CheckInteger("BinaryCrossEntropyGrad infer", SizeToLong(input_args.size()), kGreaterEqual,
|
||||
input_num, primitive->name());
|
||||
return std::make_shared<abstract::AbstractTensor>(BinaryCrossEntroyGradInferType(primitive, input_args),
|
||||
BinaryCrossEntroyGradInferShape(primitive, input_args)->shape());
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -30,6 +30,7 @@ void BNGrad::set_eps(const float eps) { (void)this->AddAttr(kEps, MakeValue(eps)
|
|||
|
||||
float BNGrad::get_eps() const {
|
||||
auto value_ptr = this->GetAttr(kEps);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
|
||||
|
@ -37,6 +38,7 @@ void BNGrad::set_momentum(const float momentum) { (void)this->AddAttr(kMomentum,
|
|||
|
||||
float BNGrad::get_momentum() const {
|
||||
auto value_ptr = this->GetAttr(kMomentum);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameBNGrad, BNGrad);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -71,6 +71,7 @@ void Conv2DBackpropFilter::set_out_channel(const int64_t out_channel) {
|
|||
|
||||
int64_t Conv2DBackpropFilter::get_out_channel() const {
|
||||
auto value_ptr = GetAttr(kOutChannel);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
|
@ -80,6 +81,7 @@ void Conv2DBackpropFilter::set_kernel_size(const std::vector<int64_t> &kernel_si
|
|||
|
||||
std::vector<int64_t> Conv2DBackpropFilter::get_kernel_size() const {
|
||||
auto value_ptr = GetAttr(kKernelSize);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
|
@ -90,6 +92,7 @@ void Conv2DBackpropFilter::set_pad_mode(const PadMode &pad_mode) {
|
|||
|
||||
PadMode Conv2DBackpropFilter::get_pad_mode() const {
|
||||
auto value_ptr = GetAttr(kPadMode);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return PadMode(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
|
@ -99,6 +102,7 @@ void Conv2DBackpropFilter::set_pad_list(const std::vector<int64_t> &pad_list) {
|
|||
|
||||
std::vector<int64_t> Conv2DBackpropFilter::get_pad_list() const {
|
||||
auto value_ptr = GetAttr(kPadList);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
|
@ -106,6 +110,7 @@ void Conv2DBackpropFilter::set_mode(const int64_t mode) { (void)this->AddAttr(kM
|
|||
|
||||
int64_t Conv2DBackpropFilter::get_mode() const {
|
||||
auto value_ptr = GetAttr(kMode);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
|
@ -113,6 +118,7 @@ void Conv2DBackpropFilter::set_stride(const std::vector<int64_t> &stride) { this
|
|||
|
||||
std::vector<int64_t> Conv2DBackpropFilter::get_stride() const {
|
||||
auto value_ptr = GetAttr(kStride);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
|
@ -122,6 +128,7 @@ void Conv2DBackpropFilter::set_dilation(const std::vector<int64_t> &dilation) {
|
|||
|
||||
std::vector<int64_t> Conv2DBackpropFilter::get_dilation() const {
|
||||
auto value_ptr = GetAttr(kDilation);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
|
@ -129,6 +136,7 @@ void Conv2DBackpropFilter::set_group(const int64_t group) { (void)this->AddAttr(
|
|||
|
||||
int64_t Conv2DBackpropFilter::get_group() const {
|
||||
auto value_ptr = GetAttr(kGroup);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
|
@ -139,6 +147,7 @@ void Conv2DBackpropFilter::set_format(const Format &format) {
|
|||
|
||||
Format Conv2DBackpropFilter::get_format() const {
|
||||
auto value_ptr = GetAttr(kFormat);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return Format(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -172,51 +172,61 @@ void Conv2DBackpropInput::set_pad_list(const std::vector<int64_t> &pad_list) {
|
|||
|
||||
int64_t Conv2DBackpropInput::get_out_channel() const {
|
||||
auto value_ptr = GetAttr(kOutChannel);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
std::vector<int64_t> Conv2DBackpropInput::get_kernel_size() const {
|
||||
auto value_ptr = GetAttr(kKernelSize);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
std::vector<int64_t> Conv2DBackpropInput::get_stride() const {
|
||||
auto value_ptr = GetAttr(kStride);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
std::vector<int64_t> Conv2DBackpropInput::get_dilation() const {
|
||||
auto value_ptr = GetAttr(kDilation);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
PadMode Conv2DBackpropInput::get_pad_mode() const {
|
||||
auto value_ptr = GetAttr(kPadMode);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return PadMode(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
std::vector<int64_t> Conv2DBackpropInput::get_pad() const {
|
||||
auto value_ptr = GetAttr(kPad);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
int64_t Conv2DBackpropInput::get_mode() const {
|
||||
auto value_ptr = GetAttr(kMode);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
int64_t Conv2DBackpropInput::get_group() const {
|
||||
auto value_ptr = GetAttr(kGroup);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
Format Conv2DBackpropInput::get_format() const {
|
||||
auto value_ptr = GetAttr(kFormat);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return Format(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
std::vector<int64_t> Conv2DBackpropInput::get_pad_list() const {
|
||||
auto value_ptr = GetAttr(kPadList);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DBackpropInput, prim::kPrimConv2DBackpropInput, Conv2DBackpropInputInfer, nullptr,
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -45,6 +45,7 @@ void DeConv2DGradFilter::set_in_channel(const int64_t in_channel) {
|
|||
|
||||
int64_t DeConv2DGradFilter::get_in_channel() const {
|
||||
auto value_ptr = GetAttr(kInChannel);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
|
@ -54,6 +55,7 @@ void DeConv2DGradFilter::set_out_channel(const int64_t out_channel) {
|
|||
|
||||
int64_t DeConv2DGradFilter::get_out_channel() const {
|
||||
auto value_ptr = GetAttr(kOutChannel);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
|
@ -63,6 +65,7 @@ void DeConv2DGradFilter::set_kernel_size(const std::vector<int64_t> &kernel_size
|
|||
|
||||
std::vector<int64_t> DeConv2DGradFilter::get_kernel_size() const {
|
||||
auto value_ptr = GetAttr(kKernelSize);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
|
@ -73,6 +76,7 @@ void DeConv2DGradFilter::set_pad_mode(const PadMode &pad_mode) {
|
|||
|
||||
PadMode DeConv2DGradFilter::get_pad_mode() const {
|
||||
auto value_ptr = GetAttr(kPadMode);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return PadMode(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
|
@ -82,6 +86,7 @@ void DeConv2DGradFilter::set_pad_list(const std::vector<int64_t> &pad_list) {
|
|||
|
||||
std::vector<int64_t> DeConv2DGradFilter::get_pad_list() const {
|
||||
auto value_ptr = GetAttr(kPadList);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
|
@ -91,6 +96,7 @@ void DeConv2DGradFilter::set_stride(const std::vector<int64_t> &stride) {
|
|||
|
||||
std::vector<int64_t> DeConv2DGradFilter::get_stride() const {
|
||||
auto value_ptr = GetAttr(kStride);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
|
@ -100,6 +106,7 @@ void DeConv2DGradFilter::set_dilation(const std::vector<int64_t> &dilation) {
|
|||
|
||||
std::vector<int64_t> DeConv2DGradFilter::get_dilation() const {
|
||||
auto value_ptr = GetAttr(kDilation);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
|
@ -107,6 +114,7 @@ void DeConv2DGradFilter::set_group(const int64_t group) { (void)this->AddAttr(kG
|
|||
|
||||
int64_t DeConv2DGradFilter::get_group() const {
|
||||
auto value_ptr = GetAttr(kGroup);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
|
@ -117,6 +125,7 @@ void DeConv2DGradFilter::set_format(const Format &format) {
|
|||
|
||||
Format DeConv2DGradFilter::get_format() const {
|
||||
auto value_ptr = GetAttr(kFormat);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return Format(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
|
@ -127,6 +136,7 @@ void DeConv2DGradFilter::set_activation_type(const ActivationType &activation_ty
|
|||
|
||||
ActivationType DeConv2DGradFilter::get_activation_type() const {
|
||||
auto value_ptr = GetAttr(kActivationType);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
|
@ -134,6 +144,7 @@ void DeConv2DGradFilter::set_has_bias(const bool has_bias) { (void)this->AddAttr
|
|||
|
||||
bool DeConv2DGradFilter::get_has_bias() const {
|
||||
auto value_ptr = GetAttr(kHasBias);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameDeConv2DGradFilter, DeConv2DGradFilter);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -28,6 +28,7 @@ void DropoutGrad::set_keep_prob(const float keep_prob) {
|
|||
|
||||
float DropoutGrad::get_keep_prob() const {
|
||||
auto value_ptr = GetAttr(kKeepProb);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
|
||||
|
@ -55,6 +56,13 @@ TypePtr DropoutGradInferType(const PrimitivePtr &prim, const std::vector<Abstrac
|
|||
|
||||
AbstractBasePtr DropoutGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
for (auto item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
const int64_t input_num = 2;
|
||||
CheckAndConvertUtils::CheckInteger("DropoutGrad infer", SizeToLong(input_args.size()), kGreaterEqual, input_num,
|
||||
primitive->name());
|
||||
return std::make_shared<abstract::AbstractTensor>(DropoutGradInferType(primitive, input_args),
|
||||
DropoutGradInferShape(primitive, input_args)->shape());
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -22,7 +22,8 @@ AbstractBasePtr FlattenGradInfer(const abstract::AnalysisEnginePtr &, const Prim
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, prim_name);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -46,6 +46,7 @@ void GroupConv2DGradInput::set_in_channel(const int64_t &in_channel) {
|
|||
|
||||
int64_t GroupConv2DGradInput::get_in_channel() const {
|
||||
auto value_ptr = GetAttr(kInChannel);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
|
@ -55,6 +56,7 @@ void GroupConv2DGradInput::set_out_channel(const int64_t &out_channel) {
|
|||
|
||||
int64_t GroupConv2DGradInput::get_out_channel() const {
|
||||
auto value_ptr = GetAttr(kOutChannel);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
|
@ -64,6 +66,7 @@ void GroupConv2DGradInput::set_kernel_size(const std::vector<int64_t> &kernel_si
|
|||
|
||||
std::vector<int64_t> GroupConv2DGradInput::get_kernel_size() const {
|
||||
auto value_ptr = GetAttr(kKernelSize);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
|
@ -74,6 +77,7 @@ void GroupConv2DGradInput::set_pad_mode(const PadMode &pad_mode) {
|
|||
|
||||
PadMode GroupConv2DGradInput::get_pad_mode() const {
|
||||
auto value_ptr = GetAttr(kPadMode);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return PadMode(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
|
@ -83,6 +87,7 @@ void GroupConv2DGradInput::set_pad_list(const std::vector<int64_t> &pad_list) {
|
|||
|
||||
std::vector<int64_t> GroupConv2DGradInput::get_pad_list() const {
|
||||
auto value_ptr = GetAttr(kPadList);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
|
@ -90,6 +95,7 @@ void GroupConv2DGradInput::set_stride(const std::vector<int64_t> &stride) { this
|
|||
|
||||
std::vector<int64_t> GroupConv2DGradInput::get_stride() const {
|
||||
auto value_ptr = GetAttr(kStride);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
|
@ -99,6 +105,7 @@ void GroupConv2DGradInput::set_dilation(const std::vector<int64_t> &dilation) {
|
|||
|
||||
std::vector<int64_t> GroupConv2DGradInput::get_dilation() const {
|
||||
auto value_ptr = GetAttr(kDilation);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
|
@ -106,6 +113,7 @@ void GroupConv2DGradInput::set_group(const int64_t &group) { (void)this->AddAttr
|
|||
|
||||
int64_t GroupConv2DGradInput::get_group() const {
|
||||
auto value_ptr = GetAttr(kGroup);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
|
@ -114,7 +122,9 @@ void GroupConv2DGradInput::set_input_shape(const std::vector<int64_t> &input_sha
|
|||
}
|
||||
|
||||
std::vector<int64_t> GroupConv2DGradInput::get_input_shape() const {
|
||||
return GetValue<std::vector<int64_t>>(GetAttr(kInputShape));
|
||||
auto value_ptr = GetAttr(kInputShape);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
void GroupConv2DGradInput::set_format(const Format &format) {
|
||||
|
@ -124,6 +134,7 @@ void GroupConv2DGradInput::set_format(const Format &format) {
|
|||
|
||||
Format GroupConv2DGradInput::get_format() const {
|
||||
auto value_ptr = GetAttr(kFormat);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return Format(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
|
@ -134,6 +145,7 @@ void GroupConv2DGradInput::set_activation_type(const ActivationType &activation_
|
|||
|
||||
ActivationType GroupConv2DGradInput::get_activation_type() const {
|
||||
auto value_ptr = GetAttr(kActivationType);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
|
@ -141,6 +153,7 @@ void GroupConv2DGradInput::set_has_bias(const bool has_bias) { (void)this->AddAt
|
|||
|
||||
bool GroupConv2DGradInput::get_has_bias() const {
|
||||
auto value_ptr = GetAttr(kHasBias);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
AbstractBasePtr GroupConv2DGradInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
@ -152,11 +165,16 @@ AbstractBasePtr GroupConv2DGradInputInfer(const abstract::AnalysisEnginePtr &, c
|
|||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
|
||||
// Infer shape
|
||||
auto shape = GetValue<std::vector<int64_t>>(primitive->GetAttr(kInputShape));
|
||||
auto shape_ptr = primitive->GetAttr(kInputShape);
|
||||
MS_EXCEPTION_IF_NULL(shape_ptr);
|
||||
auto shape = GetValue<std::vector<int64_t>>(shape_ptr);
|
||||
|
||||
// Infer type
|
||||
auto type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
|
||||
auto type_ptr = input_args[0]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(type_ptr);
|
||||
auto type_tensor_ptr = type_ptr->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(type_tensor_ptr);
|
||||
auto type = type_tensor_ptr->element();
|
||||
return std::make_shared<abstract::AbstractTensor>(type, shape);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameGroupConv2DGradInput, GroupConv2DGradInput);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -48,10 +48,12 @@ void LayerNormGrad::set_begin_params_axis(const int64_t begin_params_axis) {
|
|||
}
|
||||
int64_t LayerNormGrad::get_begin_norm_axis() const {
|
||||
auto value_ptr = this->GetAttr(kBeginNormAxis);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
int64_t LayerNormGrad::get_begin_params_axis() const {
|
||||
auto value_ptr = this->GetAttr(kBeginParamsAxis);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(LayerNormGrad, prim::kPrimLayerNormGrad, LayerNormGradInfer, nullptr, true);
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue