forked from mindspore-Ecosystem/mindspore
commit
1b3efdc70a
|
@ -0,0 +1,79 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/cpu/smooth_l1_loss_cpu_kernel.h"
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
void SmoothL1LossCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
beta_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "beta");
|
||||
CheckParam(kernel_node);
|
||||
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
|
||||
std::vector<uint64_t> x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
for (const uint64_t &d : x_shape) {
|
||||
tensor_size_ *= d;
|
||||
}
|
||||
}
|
||||
|
||||
bool SmoothL1LossCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> & /*workspace*/,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (dtype_ == kNumberTypeFloat16) {
|
||||
LaunchKernel<float16>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeFloat32) {
|
||||
LaunchKernel<float>(inputs, outputs);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void SmoothL1LossCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
auto predict_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto target_addr = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
auto result_addr = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
T zero = (T)0.0;
|
||||
T half = (T)0.5;
|
||||
T beta = (T)beta_;
|
||||
for (uint64_t i = 0; i < tensor_size_; ++i) {
|
||||
T diff = predict_addr[i] - target_addr[i];
|
||||
if (diff < zero) {
|
||||
diff = -diff;
|
||||
}
|
||||
if (diff < beta) {
|
||||
result_addr[i] = half * diff * diff / beta;
|
||||
} else {
|
||||
result_addr[i] = diff - (half * beta);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SmoothL1LossCPUKernel::CheckParam(const CNodePtr &kernel_node) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 2) {
|
||||
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but SmoothL1LossCPUKernel needs 2 input.";
|
||||
}
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 1) {
|
||||
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but SmoothL1LossCPUKernel needs 1 output.";
|
||||
}
|
||||
if (beta_ == 0.0) {
|
||||
MS_LOG(EXCEPTION) << "Attr beta can not be zero.";
|
||||
}
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SMOOTH_L1_LOSS_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SMOOTH_L1_LOSS_CPU_KERNEL_H_
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class SmoothL1LossCPUKernel : public CPUKernel {
|
||||
public:
|
||||
SmoothL1LossCPUKernel() = default;
|
||||
~SmoothL1LossCPUKernel() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
template <typename T>
|
||||
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
private:
|
||||
void CheckParam(const CNodePtr &kernel_node);
|
||||
float beta_ = 1.0;
|
||||
TypeId dtype_{kTypeUnknown};
|
||||
uint64_t tensor_size_ = 1;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
SmoothL1Loss,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
SmoothL1LossCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
SmoothL1Loss,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SmoothL1LossCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SMOOTH_L1_LOSS_CPU_KERNEL_H_
|
|
@ -0,0 +1,77 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/cpu/smooth_l1_loss_grad_cpu_kernel.h"
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
void SmoothL1LossGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
beta_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "beta");
|
||||
CheckParam(kernel_node);
|
||||
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
|
||||
std::vector<uint64_t> x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
for (const uint64_t &d : x_shape) {
|
||||
tensor_size_ *= d;
|
||||
}
|
||||
}
|
||||
|
||||
bool SmoothL1LossGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> & /*workspace*/,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (dtype_ == kNumberTypeFloat16) {
|
||||
LaunchKernel<float16>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeFloat32) {
|
||||
LaunchKernel<float>(inputs, outputs);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void SmoothL1LossGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
auto predict_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto target_addr = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
auto dloss_addr = reinterpret_cast<T *>(inputs[2]->addr);
|
||||
auto result_addr = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
T beta = (T)beta_;
|
||||
for (uint64_t i = 0; i < tensor_size_; ++i) {
|
||||
T diff = predict_addr[i] - target_addr[i];
|
||||
if (diff > beta) {
|
||||
result_addr[i] = dloss_addr[i];
|
||||
} else if (diff < -beta) {
|
||||
result_addr[i] = -dloss_addr[i];
|
||||
} else {
|
||||
result_addr[i] = (diff / beta) * dloss_addr[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SmoothL1LossGradCPUKernel::CheckParam(const CNodePtr &kernel_node) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 3) {
|
||||
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but SmoothL1LossGradCPUKernel needs 3 input.";
|
||||
}
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 1) {
|
||||
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but SmoothL1LossGradCPUKernel needs 1 output.";
|
||||
}
|
||||
if (beta_ == 0.0) {
|
||||
MS_LOG(EXCEPTION) << "Attr beta can not be zero.";
|
||||
}
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,64 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SMOOTH_L1_LOSS_GRAD_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SMOOTH_L1_LOSS_GRAD_CPU_KERNEL_H_
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class SmoothL1LossGradCPUKernel : public CPUKernel {
|
||||
public:
|
||||
SmoothL1LossGradCPUKernel() = default;
|
||||
~SmoothL1LossGradCPUKernel() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
template <typename T>
|
||||
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
private:
|
||||
void CheckParam(const CNodePtr &kernel_node);
|
||||
float beta_ = 1.0;
|
||||
TypeId dtype_{kTypeUnknown};
|
||||
uint64_t tensor_size_ = 1;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(SmoothL1LossGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
SmoothL1LossGradCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(SmoothL1LossGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
SmoothL1LossGradCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SMOOTH_L1_LOSS_GRAD_CPU_KERNEL_H_
|
|
@ -0,0 +1,103 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/cpu/tile_cpu_kernel.h"
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
void TileCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
CheckParam(kernel_node);
|
||||
x_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
y_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
multiples_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "multiples");
|
||||
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
|
||||
}
|
||||
|
||||
bool TileCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> & /*workspace*/,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (dtype_ == kNumberTypeInt32) {
|
||||
LaunchKernel<int>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeFloat32) {
|
||||
LaunchKernel<float>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeInt64) {
|
||||
LaunchKernel<int64_t>(inputs, outputs);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void TileRecTask(T *x, T *y, size_t dim, size_t *offset, std::vector<size_t> *pos, const std::vector<int> &multiples,
|
||||
const std::vector<size_t> &cargo_x, const std::vector<size_t> &cargo_y,
|
||||
const std::vector<size_t> &x_shape) {
|
||||
if (dim == x_shape.size()) {
|
||||
return;
|
||||
}
|
||||
for (size_t i = 0; i < x_shape[dim]; ++i) {
|
||||
(*pos)[dim] = i;
|
||||
if (dim == x_shape.size() - 1) {
|
||||
size_t x_offset = 0;
|
||||
for (size_t j = 0; j < (*pos).size(); ++j) {
|
||||
x_offset += (*pos)[j] * cargo_x[j];
|
||||
}
|
||||
memcpy(y + *offset, x + x_offset, sizeof(T));
|
||||
*offset += 1;
|
||||
continue;
|
||||
}
|
||||
TileRecTask(x, y, dim + 1, offset, pos, multiples, cargo_x, cargo_y, x_shape);
|
||||
}
|
||||
for (int m = 0; m < multiples[dim] - 1; ++m) {
|
||||
size_t y_offset = *offset - cargo_y[dim];
|
||||
memcpy(y + *offset, y + y_offset, cargo_y[dim] * sizeof(T));
|
||||
*offset += cargo_y[dim];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void TileCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
|
||||
auto x_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto y_addr = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
size_t ones = multiples_.size() - x_shape_.size();
|
||||
if (ones > 0) {
|
||||
for (size_t i = 0; i < ones; ++i) {
|
||||
x_shape_.insert(x_shape_.begin(), 1);
|
||||
}
|
||||
}
|
||||
int d = multiples_.size();
|
||||
std::vector<size_t> pos(d, 0);
|
||||
std::vector<size_t> cargo_x(d, 1);
|
||||
std::vector<size_t> cargo_y = x_shape_;
|
||||
for (int i = d - 2; i >= 0; --i) {
|
||||
cargo_x[i] = x_shape_[i + 1] * cargo_x[i];
|
||||
cargo_y[i] *= cargo_y[i + 1] * multiples_[i + 1];
|
||||
}
|
||||
size_t offset = 0;
|
||||
TileRecTask<T>(x_addr, y_addr, 0, &offset, &pos, multiples_, cargo_x, cargo_y, x_shape_);
|
||||
}
|
||||
|
||||
void TileCPUKernel::CheckParam(const CNodePtr &kernel_node) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 1) {
|
||||
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but TileCPUKernel needs 1 input.";
|
||||
}
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 1) {
|
||||
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but TileCPUKernel needs 1 output.";
|
||||
}
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class TileCPUKernel : public CPUKernel {
|
||||
public:
|
||||
TileCPUKernel() = default;
|
||||
~TileCPUKernel() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
template <typename T>
|
||||
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
private:
|
||||
void CheckParam(const CNodePtr &kernel_node);
|
||||
std::vector<size_t> x_shape_;
|
||||
std::vector<size_t> y_shape_;
|
||||
std::vector<int> multiples_;
|
||||
TypeId dtype_{kTypeUnknown};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), TileCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), TileCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), TileCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), TileCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_
|
|
@ -4048,7 +4048,7 @@ class GatherD(PrimitiveWithInfer):
|
|||
- **x** (Tensor) - The source tensor.
|
||||
- **dim** (int) - The axis along which to index. It must be int32. Only constant value is allowed.
|
||||
- **index** (Tensor) - The indices of elements to gather. It can be one of the following data types:
|
||||
int32, int64.
|
||||
int32, int64.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.composite import GradOperation
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, sigma=1.0):
|
||||
super(Net, self).__init__()
|
||||
self.SmoothL1Loss = P.SmoothL1Loss(sigma)
|
||||
|
||||
def construct(self, pred, gt):
|
||||
return self.SmoothL1Loss(pred, gt)
|
||||
|
||||
|
||||
class Grad(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(Grad, self).__init__()
|
||||
self.grad = GradOperation(get_all=True, sens_param=True)
|
||||
self.network = network
|
||||
|
||||
def construct(self, pred, gt, dout):
|
||||
return self.grad(self.network)(pred, gt, dout)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_net():
|
||||
pred = np.random.randn(2, 4).astype(np.float32)
|
||||
gt = np.random.randn(2, 4).astype(np.float32)
|
||||
dout = np.random.randn(2, 4).astype(np.float32)
|
||||
smooth_l1_loss_grad = Grad(Net())
|
||||
output = smooth_l1_loss_grad(Tensor(pred), Tensor(gt), Tensor(dout))
|
||||
print("------------- input ---------------")
|
||||
print("predict:\n", pred)
|
||||
print("grount truth:\n", gt)
|
||||
print("dout:\n", dout)
|
||||
print("------------- output ---------------")
|
||||
print("predict grad:\n", output[0].asnumpy())
|
|
@ -0,0 +1,48 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, sigma=1.0):
|
||||
super(Net, self).__init__()
|
||||
self.SmoothL1Loss = P.SmoothL1Loss(sigma)
|
||||
|
||||
def construct(self, pred, gt):
|
||||
return self.SmoothL1Loss(pred, gt)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_net():
|
||||
pred = np.random.randn(2, 4).astype(np.float32)
|
||||
gt = np.random.randn(2, 4).astype(np.float32)
|
||||
smooth_l1_loss = Net()
|
||||
loss = smooth_l1_loss(Tensor(pred), Tensor(gt))
|
||||
print("------------- input ---------------")
|
||||
print("predict:\n", pred)
|
||||
print("grount truth:\n", gt)
|
||||
print("------------- output ---------------")
|
||||
print("loss:\n", loss.asnumpy())
|
|
@ -0,0 +1,45 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.tile = P.Tile()
|
||||
|
||||
def construct(self, x):
|
||||
return self.tile(x, (1, 4))
|
||||
|
||||
|
||||
arr_x = np.array([[0], [1], [2], [3]]).astype(np.int32)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_net():
|
||||
tile = Net()
|
||||
print(arr_x)
|
||||
output = tile(Tensor(arr_x))
|
||||
print(output.asnumpy())
|
Loading…
Reference in New Issue