forked from mindspore-Ecosystem/mindspore
!7786 Add SampledSoftmaxLoss GPU Kernel
Merge pull request !7786 from JonathanY/ops_oct
commit deb17b36c1
@@ -30,7 +30,8 @@ namespace kernel {
 template <typename T, typename S>
 class UniformSamplerGpuKernel : public GpuKernel {
  public:
-  UniformSamplerGpuKernel() : num_true_(0), num_sampled_(0), unique_(false), range_max_(0), input_size_(0) {}
+  UniformSamplerGpuKernel()
+      : num_true_(0), num_sampled_(0), unique_(false), range_max_(0), input_size_(0), remove_accidental_hits_(false) {}
   ~UniformSamplerGpuKernel() override = default;
 
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -43,6 +44,16 @@ class UniformSamplerGpuKernel : public GpuKernel {
     T *sampled_candidates = GetDeviceAddress<T>(outputs, 0);
     S *true_expected_count = GetDeviceAddress<S>(outputs, 1);
     S *sampled_expected_count = GetDeviceAddress<S>(outputs, 2);
+    if (remove_accidental_hits_) {
+      T *input = GetDeviceAddress<T>(inputs, 0);
+      array_input_ = std::vector<T>(input_size_, 0);
+      CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(&array_input_[0], input, input_size_ * sizeof(T),
+                                                 cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                                 "cudaMemcpyAsync sampled_candidates failed");
+      for (const auto item : array_input_) {
+        set_input_.insert(item);
+      }
+    }
     int counter = Sampling();
     float prob = Probability();
     size_t sampled_candidates_size = num_sampled_ * sizeof(T);
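Note: before sampling, `Launch` now stages the true-class labels on the host — the label buffer is copied back with `cudaMemcpyAsync` and folded into `set_input_` for O(1) membership tests. A rough host-side model of that bookkeeping (hypothetical helper name; the kernel itself reads the raw device buffer):

```python
import numpy as np

def build_true_class_set(labels):
    """Model of the set_input_ preparation: flatten the (batch_size, num_true)
    label matrix and keep the distinct class ids for O(1) collision checks."""
    return set(np.asarray(labels).ravel().tolist())

# Example: duplicate labels collapse into a single set entry.
print(build_true_class_set([[1, 3], [3, 5]]))  # {1, 3, 5}
```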
@@ -72,6 +83,7 @@ class UniformSamplerGpuKernel : public GpuKernel {
     unique_ = GetAttr<bool>(kernel_node, "unique");
     range_max_ = GetAttr<int>(kernel_node, "range_max");
     int seed = GetAttr<int>(kernel_node, "seed");
+    remove_accidental_hits_ = GetAttr<bool>(kernel_node, "remove_accidental_hits");
     if (seed == 0) seed = time(NULL);
     generator_.seed(seed);
     auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
@@ -80,6 +92,9 @@ class UniformSamplerGpuKernel : public GpuKernel {
       return false;
     }
     input_size_ = input_shape[0] * input_shape[1];
+    if (num_sampled_ * num_true_ + static_cast<int>(input_size_) > range_max_ * num_true_) {
+      remove_accidental_hits_ = false;
+    }
     InitSizeLists();
     return true;
   }
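Note: the new guard in `Init()` disables hit removal when it could make sampling infeasible — with unique sampling, excluding every true class must still leave at least `num_sampled` classes to draw from, otherwise the rejection loop might never terminate. A hedged restatement of the condition (hypothetical helper name):

```python
def can_remove_hits(num_sampled, num_true, input_size, range_max):
    """Mirror of the Init() guard: hit removal stays enabled only while enough
    classes remain to draw num_sampled distinct, non-target candidates."""
    return num_sampled * num_true + input_size <= range_max * num_true
```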
@@ -105,7 +120,8 @@ class UniformSamplerGpuKernel : public GpuKernel {
     while (picked < num_sampled_) {
       tmp = distribution(generator_);
       counter++;
-      if (set_container.find(tmp) == set_container.end()) {
+      if ((set_container.find(tmp) == set_container.end()) &&
+          ((!remove_accidental_hits_) || set_input_.find(tmp) == set_input_.end())) {
         set_container.insert(tmp);
         sampled_candidates_.push_back(tmp);
         picked++;
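Note: this rejection loop is the heart of the change. A candidate is now rejected either because it was already drawn (the `unique` path) or because it collides with a true class when `remove_accidental_hits_` is on, while `counter` keeps counting every draw, accepted or rejected, so the expected counts derived from it stay consistent. A minimal host-side Python model of the loop, assuming uniform integer draws over `[0, range_max)`:

```python
import random

def sample_unique_no_hits(num_sampled, range_max, true_classes,
                          remove_accidental_hits, seed=1):
    """Model of Sampling(): draw until num_sampled distinct candidates are
    accepted, skipping true classes when hit removal is on. Returns the
    candidates plus the total number of draws (the kernel's `counter`)."""
    rng = random.Random(seed)
    picked, order, counter = set(), [], 0
    while len(order) < num_sampled:
        tmp = rng.randrange(range_max)
        counter += 1
        if tmp not in picked and (not remove_accidental_hits or tmp not in true_classes):
            picked.add(tmp)
            order.append(tmp)
    return order, counter
```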
@@ -133,6 +149,9 @@ class UniformSamplerGpuKernel : public GpuKernel {
   bool unique_;
   int range_max_;
   size_t input_size_;
+  bool remove_accidental_hits_;
+  std::vector<T> array_input_;
+  std::set<int> set_input_;
   std::default_random_engine generator_;
   std::vector<int> sampled_candidates_;
   std::vector<size_t> input_size_list_;
@@ -20,8 +20,9 @@ It shows how well the model works on a dataset and the optimization target which
 """
 
 from .loss import L1Loss, MSELoss, SmoothL1Loss, \
-    SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss
+    SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss, \
+    SampledSoftmaxLoss
 
 __all__ = ['L1Loss', 'MSELoss', 'SmoothL1Loss',
            'SoftmaxCrossEntropyWithLogits', 'BCELoss',
-           'CosineEmbeddingLoss']
+           'CosineEmbeddingLoss', 'SampledSoftmaxLoss']
@@ -263,6 +263,186 @@ class SoftmaxCrossEntropyWithLogits(_Loss):
         return self.get_loss(x)
 
 
+class SampledSoftmaxLoss(_Loss):
+    r"""
+    Computes the sampled softmax training loss.
+
+    Args:
+        num_sampled (int): The number of classes to randomly sample per batch.
+        num_classes (int): The number of possible classes.
+        num_true (int): The number of target classes per training example.
+        sampled_values (Tuple): Tuple of (`sampled_candidates`, `true_expected_count`,
+            `sampled_expected_count`) returned by a `*_candidate_sampler` function.
+            Default: None, in which case the internal `UniformSampler` is applied.
+        remove_accidental_hits (bool): Whether to remove "accidental hits"
+            where a sampled class equals one of the target classes. Default is True.
+        seed (int): Random seed for candidate sampling. Default: 0.
+        reduction (str): Type of reduction to be applied to loss. The optional values are "mean", "sum", and "none".
+            If "none", do not perform reduction. Default: "none".
+
+    Inputs:
+        - **weights** (Tensor) - Tensor of shape (C, dim).
+        - **biases** (Tensor) - Tensor of shape (C). The class biases.
+        - **labels** (Tensor) - Tensor of shape (N, num_true), type `int64`. The
+          target classes.
+        - **inputs** (Tensor) - Tensor of shape (N, dim). The forward activations of
+          the input network.
+
+    Outputs:
+        Tensor, a tensor of shape (N) with the per-example sampled softmax losses.
+
+    """
+
+    def __init__(self, num_sampled, num_classes, num_true=1,
+                 sampled_values=None, remove_accidental_hits=True, seed=0,
+                 reduction='none'):
+        super(SampledSoftmaxLoss, self).__init__()
+        self.num_sampled = num_sampled
+        self.num_classes = num_classes
+        self.num_true = num_true
+        self.sampled_values = sampled_values
+        self.remove_accidental_hits = remove_accidental_hits
+        self.seed = seed
+        self.sampler = P.UniformSampler(
+            num_true,
+            num_sampled,
+            True,
+            num_classes,
+            seed,
+            remove_accidental_hits)
+        self.cast = P.Cast()
+        self.reshape = P.Reshape()
+        self.shape = P.Shape()
+        self.exp = P.Exp()
+        self.log = P.Log()
+        self.slice_op = P.Slice()
+        self.matmul = P.MatMul(False, True)
+        self.gather_v2 = P.GatherV2()
+        self.reduce_max_true = P.ReduceMax(True)
+        self.reduce_sum = P.ReduceSum()
+        self.reduce_sum_true = P.ReduceSum(True)
+        self.concat_dim0 = P.Concat(0)
+        self.concat_dim1 = P.Concat(1)
+        self.ones_like = P.OnesLike()
+        self.zeros_like = P.ZerosLike()
+        self.mul = P.Mul()
+        self.expand_dims = P.ExpandDims()
+
+    def construct(self, weights, biases, labels, inputs):
+        logits, labels = self._compute_sampled_logits(
+            weights=weights,
+            biases=biases,
+            labels=labels,
+            inputs=inputs,
+            num_true=self.num_true,
+            sampled_values=self.sampled_values,
+            subtract_log_q=True)
+
+        x = self._softmax_cross_entropy(logits, labels)
+        return x
+
+    def _softmax_cross_entropy(self, logits, targets):
+        stable_exp_logits = self.exp(logits - self.reduce_max_true(logits, 1))
+        pred = stable_exp_logits / self.reduce_sum_true(stable_exp_logits, 1)
+        return -self.reduce_sum(targets * self.log(pred + 1.0e-20), 1)
+
+    def _compute_sampled_logits(self, weights,
+                                biases,
+                                labels,
+                                inputs,
+                                num_true=1,
+                                sampled_values=None,
+                                subtract_log_q=True):
+        """Helper function for SampledSoftmaxLoss functions.
+
+        Computes sampled output training logits and labels suitable
+        for sampled softmax.
+
+        Note: In the case where num_true > 1, we assign to each target class
+        the target probability 1 / num_true so that the target probabilities
+        sum to 1 per-example.
+
+        Args:
+            weights (Tensor): Tensor of shape `[num_classes, dim]`.
+            biases (Tensor): Tensor of shape `[num_classes]`.
+            labels (Tensor): Tensor of shape `[batch_size, num_true]`. The target classes.
+            inputs (Tensor): Tensor of shape `[batch_size, dim]`. The forward
+                activations of the input network.
+            num_true (int): The number of target classes per training example.
+            sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
+                `sampled_expected_count`) returned by a `UniformSampler` function.
+            subtract_log_q: A `bool`. Whether to subtract the log expected count of
+                the labels in the sample to get the logits of the true labels.
+                Default is True.
+        Returns:
+            out_logits: `Tensor` object with shape
+                `[batch_size, num_true + num_sampled]`
+            out_labels: A Tensor object with the same shape as `out_logits`.
+        """
+        if not labels.dtype == mstype.int32:
+            labels = self.cast(labels, mstype.int32)
+        labels = self.reshape(labels, (-1, num_true))
+        labels_flat = self.reshape(labels, (-1,))
+
+        # Sample the negative labels.
+        #   sampled shape: [num_sampled] tensor
+        #   true_expected_count shape = [batch_size, 1] tensor
+        #   sampled_expected_count shape = [num_sampled] tensor
+        if sampled_values is None:
+            sampled_values = self.sampler(labels)
+
+        (sampled, true_expected_count, sampled_expected_count) = sampled_values
+
+        if not sampled.dtype == mstype.int32:
+            sampled = self.cast(sampled, mstype.int32)
+        all_ids = self.concat_dim0((labels_flat, sampled))
+        all_w = self.gather_v2(weights, all_ids, 0)
+
+        n_true = self.shape(labels_flat)[0]
+        n_sampled = self.shape(sampled)[0]
+        n_dim = self.shape(all_w)[1]
+
+        # true_w shape is [batch_size * num_true, dim]
+        true_w = self.slice_op(all_w, [0, 0], [n_true, n_dim])
+        sampled_w = self.slice_op(all_w, [n_true, 0], [n_sampled, n_dim])
+        sampled_logits = self.matmul(inputs, sampled_w)
+
+        all_b = self.gather_v2(biases, all_ids, 0)
+        true_b = self.slice_op(all_b, [0], [n_true])
+        sampled_b = self.slice_op(all_b, [n_true], [n_sampled])
+
+        # inputs shape is [batch_size, dim]
+        # true_w shape is [batch_size * num_true, dim]
+        # row_wise_dots is [batch_size, num_true, dim]
+        new_true_w_shape = (-1, num_true, n_dim)
+        row_wise_dots = self.mul(self.expand_dims(inputs, 1),
+                                 self.reshape(true_w, new_true_w_shape))
+
+        # We want the row-wise dot plus biases which yields a
+        # [batch_size, num_true] tensor of true_logits.
+        dots_as_matrix = self.reshape(row_wise_dots, (-1, n_dim))
+        true_logits = self.reshape(self.reduce_sum(dots_as_matrix, 1), (-1, num_true))
+        true_b = self.reshape(true_b, (-1, num_true))
+        true_logits += true_b
+        sampled_logits += sampled_b
+
+        if subtract_log_q:
+            # Subtract log of Q(l), prior probability that l appears in sampled.
+            true_logits -= self.log(true_expected_count)
+            sampled_logits -= self.log(sampled_expected_count)
+
+        # Construct output logits and labels. The true labels/logits start at col 0.
+        out_logits = self.concat_dim1((true_logits, sampled_logits))
+
+        # true_logits is a float tensor, ones_like(true_logits) is a float
+        # tensor of ones. We then divide by num_true to ensure the per-example
+        # labels sum to 1.0, i.e. form a proper probability distribution.
+        out_labels = self.concat_dim1((
+            self.ones_like(true_logits) / num_true,
+            self.zeros_like(sampled_logits)
+        ))
+        return out_logits, out_labels
+
+
 class BCELoss(_Loss):
     r"""
     BCELoss creates a criterion to measure the Binary Cross Entropy between the true labels and predicted labels.
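Note: for orientation, a minimal usage sketch of the new loss — the shapes follow the docstring, but the values are illustrative and not from the source. With `sampled_values=None` the internal `P.UniformSampler` draws the negatives:

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

num_classes, dim, batch_size = 7, 10, 3
weights = Tensor(np.random.randn(num_classes, dim).astype(np.float32))
biases = Tensor(np.random.randn(num_classes).astype(np.float32))
labels = Tensor(np.array([0, 1, 2]))  # one true class per example
inputs = Tensor(np.random.randn(batch_size, dim).astype(np.float32))

loss = nn.SampledSoftmaxLoss(num_sampled=4, num_classes=num_classes, num_true=1,
                             sampled_values=None, remove_accidental_hits=False, seed=1)
print(loss(weights, biases, labels, inputs))  # shape (3,): per-example losses
```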
@@ -5831,6 +5831,7 @@ class UniformSampler(PrimitiveWithInfer):
         unique (bool): Whether all sampled classes in a batch are unique.
         range_max (int): The number of possible classes.
         seed (int): Random seed, must be non-negative. Default: 0.
+        remove_accidental_hits (bool): Whether accidental hits are removed. Default: False.
 
     Inputs:
         true_classes (int): A tensor. The target classes with a tensor shape of (batch_size, num_true).
@@ -5850,13 +5851,14 @@ class UniformSampler(PrimitiveWithInfer):
         [1, 1, 3], [[0.75], [0.75], [0.75], [0.75], [0.75]], [0.75, 0.75, 0.75]
     """
     @prim_attr_register
-    def __init__(self, num_true, num_sampled, unique, range_max, seed=0):
+    def __init__(self, num_true, num_sampled, unique, range_max, seed=0, remove_accidental_hits=False):
         """Initialize UniformSampler"""
         validator.check_value_type("num_true", num_true, [int], self.name)
         validator.check_value_type("num_sampled", num_sampled, [int], self.name)
         validator.check_value_type("unique", unique, [bool], self.name)
         validator.check_value_type("range_max", range_max, [int], self.name)
         validator.check_value_type("seed", seed, [int], self.name)
+        validator.check_value_type("remove_accidental_hits", remove_accidental_hits, [bool], self.name)
         validator.check("value of num_sampled", num_sampled, '', 0, Rel.GT, self.name)
         if unique:
             validator.check('value of num_sampled', num_sampled, "value of range_max", range_max, Rel.LE, self.name)
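Note: the `0.75` values in the docstring example follow directly from uniform sampling with replacement — each class is expected to appear `num_sampled / range_max` times among the candidates. Assuming `num_sampled=3` and `range_max=4` (inferred from the output shapes, not stated in this excerpt):

```python
num_sampled, range_max = 3, 4  # assumed parameters for the docstring example
expected_count = num_sampled * (1.0 / range_max)
print(expected_count)  # 0.75, matching the example output
```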
@@ -0,0 +1,137 @@
+# Copyright 2019 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+
+def generate_test_data(num_classes, batch_size, sampled):
+    dim = 10
+    weights_s = np.linspace(start=1, stop=num_classes * dim, num=num_classes * dim)
+    weights_s = np.reshape(weights_s, (num_classes, dim)).astype(np.float32) / 100.0
+    biases_s = np.linspace(start=1, stop=num_classes, num=num_classes)
+    biases_s = np.reshape(biases_s, (num_classes,)).astype(np.float32) / 100.0
+    hidden_acts_s = np.linspace(start=1, stop=batch_size * dim, num=batch_size * dim)
+    hidden_acts_s = np.reshape(
+        hidden_acts_s, (batch_size, dim)).astype(np.float32) / 100.0
+
+    true_exp = np.full([batch_size, 1], fill_value=0.5, dtype=np.float32)
+    sampled_exp = np.full([len(sampled)], fill_value=0.5, dtype=np.float32)
+    sampled_values = (Tensor(sampled), Tensor(true_exp), Tensor(sampled_exp))
+    return weights_s, biases_s, hidden_acts_s, sampled_values
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_sampled_softmax_loss_assigned_sampler():
+    np.random.seed(0)
+    num_classes = 7
+    batch_size = 3
+    labels = [0, 1, 2]
+    (weights, biases, hidden_acts, sampled_vals) = generate_test_data(
+        num_classes=num_classes,
+        batch_size=batch_size,
+        sampled=[4, 0, 2, 3])
+
+    def case_not_remove_accidental_hits():
+        loss = nn.SampledSoftmaxLoss(
+            num_sampled=4,
+            num_classes=num_classes,
+            num_true=1,
+            sampled_values=sampled_vals,
+            remove_accidental_hits=False)
+
+        got_sampled_softmax_loss = loss(Tensor(weights), Tensor(biases),
+                                        Tensor(labels), Tensor(hidden_acts))
+        exp_sampled_softmax_loss = np.array(
+            [1.7318448, 1.8015041, 1.7211525]).astype(np.float32)
+        assert np.allclose(got_sampled_softmax_loss.asnumpy(),
+                           exp_sampled_softmax_loss)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    case_not_remove_accidental_hits()
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
+    case_not_remove_accidental_hits()
+
+    (weights, biases, hidden_acts, sampled_vals) = generate_test_data(
+        num_classes=num_classes,
+        batch_size=batch_size,
+        sampled=[4, 5, 6, 3])
+
+    def case_remove_accidental_hits():
+        loss = nn.SampledSoftmaxLoss(
+            num_sampled=4,
+            num_classes=num_classes,
+            num_true=1,
+            sampled_values=sampled_vals,
+            remove_accidental_hits=True)
+
+        got_sampled_softmax_loss = loss(Tensor(weights), Tensor(biases),
+                                        Tensor(labels), Tensor(hidden_acts))
+        exp_sampled_softmax_loss = np.array(
+            [[1.85211, 2.10999, 2.20862]]).astype(np.float32)
+        assert np.allclose(got_sampled_softmax_loss.asnumpy(),
+                           exp_sampled_softmax_loss)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    case_remove_accidental_hits()
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
+    case_remove_accidental_hits()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_sampled_softmax_loss_none_sampler():
+    np.random.seed(0)
+    num_classes = 7
+    batch_size = 3
+    labels = [0, 1, 2]
+    (weights, biases, hidden_acts, _) = generate_test_data(
+        num_classes=num_classes,
+        batch_size=batch_size,
+        sampled=[4, 0, 2, 3])
+
+    def case_no_sampler():
+        loss = nn.SampledSoftmaxLoss(
+            num_sampled=4,
+            num_classes=num_classes,
+            num_true=1,
+            sampled_values=None,
+            seed=1,
+            remove_accidental_hits=False)
+
+        got_sampled_softmax_loss = loss(Tensor(weights), Tensor(biases),
+                                        Tensor(labels), Tensor(hidden_acts))
+        exp_sampled_softmax_loss = np.array(
+            [1.7345718, 1.820291, 1.7704818]).astype(np.float32)
+        assert np.allclose(got_sampled_softmax_loss.asnumpy(),
+                           exp_sampled_softmax_loss)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    case_no_sampler()
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
+    case_no_sampler()
+
+
+if __name__ == "__main__":
+    test_sampled_softmax_loss_assigned_sampler()
+    test_sampled_softmax_loss_none_sampler()
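Note: `generate_test_data` assigns the constant expected count 0.5 to every true and sampled class. Since `_compute_sampled_logits` subtracts `log(expected_count)` from each logit, a constant expected count shifts all logits equally and leaves the softmax, and hence the loss, unchanged. A short sketch of that shift invariance (plain NumPy, not from the source):

```python
import numpy as np

def softmax_xent(logits, onehot):
    # Numerically stable softmax cross-entropy, mirroring _softmax_cross_entropy.
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    p = e / e.sum(axis=1, keepdims=True)
    return -(onehot * np.log(p + 1e-20)).sum(axis=1)

logits = np.array([[2.0, 0.5, -1.0]])
onehot = np.array([[1.0, 0.0, 0.0]])
shifted = logits - np.log(0.5)  # constant expected count: same shift everywhere
assert np.allclose(softmax_xent(logits, onehot), softmax_xent(shifted, onehot))
```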
@@ -35,6 +35,25 @@ def uniform_sampler(x, num_true, num_sampled, unique, range_max):
     out1, out2, out3 = uniform_sampler_net(Tensor(x.astype(np.int32)))
     return out1.shape, out2.shape, out3.shape
 
+
+class UniformSamplerHitNet(nn.Cell):
+    def __init__(self, num_true, num_sampled, unique, range_max, seed, remove_accidental_hits):
+        super(UniformSamplerHitNet, self).__init__()
+        self.sampler = P.UniformSampler(num_true, num_sampled, unique, range_max, seed=seed,
+                                        remove_accidental_hits=remove_accidental_hits)
+
+    def construct(self, x):
+        return self.sampler(x)
+
+
+def uniform_sampler_hit(x, num_true, num_sampled, unique, range_max, seed,
+                        remove_accidental_hits):
+    uniform_sampler_net = UniformSamplerHitNet(num_true, num_sampled, unique, range_max,
+                                               seed, remove_accidental_hits)
+    out1, out2, out3 = uniform_sampler_net(Tensor(x.astype(np.int32)))
+    return out1, out2, out3
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
@@ -114,3 +133,23 @@ def test_uniform_sampler_large_random():
     np.testing.assert_array_equal(ms1, expected_1)
     np.testing.assert_array_equal(ms2, expected_2)
     np.testing.assert_array_equal(ms3, expected_3)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_uniform_sampler_unique_1_true_hit():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    ms1, _, _ = uniform_sampler_hit(np.array([[1]]), 1, 3, True, 4, 1, False)
+    expected_1 = np.array([0, 3, 1])
+    np.testing.assert_array_equal(ms1.asnumpy(), expected_1)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_uniform_sampler_unique_1_true_no_hit():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    ms1, _, _ = uniform_sampler_hit(np.array([[1]]), 1, 3, True, 4, 1, True)
+    expected_1 = np.array([0, 3, 2])
+    np.testing.assert_array_equal(ms1.asnumpy(), expected_1)