!8441 Add Parallel Implementation of UniformCandidateSampler

From: @huangxinjing
mindspore-ci-bot 2020-11-18 18:59:31 +08:00 committed by Gitee
commit ee72de1db2
9 changed files with 610 additions and 1 deletion

View File

@@ -910,6 +910,21 @@ double GatherV2PCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs,
return result;
}
double UniformCandidateSamplerCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
const std::vector<TensorInfo> &outputs,
int64_t stage_id) const {
double result = 0.0;
Shape input0_slice_shape = inputs[0].slice_shape();
if (inputs_type_lengths_.size() != inputs.size()) {
MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size()
<< " for UniformCandidateSampler cost";
}
result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
return result;
}
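The forward cost above is simply the sliced input volume times the element size; a minimal Python sketch of the same arithmetic (the helper name and arguments are hypothetical, not part of the patch):

from functools import reduce
from operator import mul

def forward_computation_cost(slice_shape, bytes_per_element):
    # ListProduct(input0_slice_shape) * inputs_type_lengths_[0] from the C++ above
    return reduce(mul, slice_shape, 1) * float(bytes_per_element)

# A [48, 16] input under strategy (4, 1) leaves a [12, 16] slice per device;
# with 4-byte elements the modeled cost is 12 * 16 * 4 = 768.
assert forward_computation_cost([12, 16], 4) == 768.0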
double GatherV2PCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
double result = 0.0;

View File

@@ -684,6 +684,38 @@ class UniqueCost : public OperatorCost {
using UniqueCostPtr = std::shared_ptr<UniqueCost>;
class UniformCandidateSamplerCost : public OperatorCost {
public:
explicit UniformCandidateSamplerCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
UniformCandidateSamplerCost() : OperatorCost(false) {}
~UniformCandidateSamplerCost() override = default;
double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override {
return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
}
double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override {
return 0;
}
double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override {
return 0;
}
double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override {
return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
}
double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override;
double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t) const override {
return 0.0;
}
};
using UniformCandidateSamplerCostPtr = std::shared_ptr<UniformCandidateSamplerCost>;
class GatherV2Cost : public OperatorCost {
public:
explicit GatherV2Cost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}

View File

@@ -176,6 +176,7 @@ REGISTER(ExpandDimsInfo);
REGISTER(SqueezeInfo);
REGISTER(SigmoidCrossEntropyWithLogitsInfo);
REGISTER(SquareInfo);
REGISTER(UniformCandidateSamplerInfo);
REGISTER(UnsortedSegmentSumInfo);
REGISTER(UnsortedSegmentMinInfo);
REGISTER(GatherV2PInfo);

View File

@@ -47,6 +47,7 @@
#include "frontend/parallel/ops_info/pack_info.h"
#include "frontend/parallel/ops_info/broadcast_to_info.h"
#include "frontend/parallel/ops_info/unique_info.h"
#include "frontend/parallel/ops_info/uniform_candidate_sampler_info.h"
#include "frontend/parallel/ops_info/reluv2_info.h"
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_

View File

@@ -102,6 +102,12 @@ constexpr char END[] = "end";
constexpr char STRIDES[] = "strides";
constexpr char GROUP[] = "group";
constexpr char FUSION[] = "fusion";
constexpr char NUM_SAMPLED[] = "num_sampled";
constexpr char NUM_TRUE[] = "num_true";
constexpr char SEED[] = "seed";
constexpr char RANGE_MAX[] = "range_max";
constexpr char REMOVE_ACCIDENTAL_HITS[] = "remove_accidental_hits";
constexpr char UNIQUE_STRING[] = "unique";
constexpr char AXIS[] = "axis";
constexpr char AXES[] = "axes";
constexpr char START[] = "start";
@@ -191,6 +197,7 @@ constexpr char DIV[] = "Div";
constexpr char REAL_DIV[] = "RealDiv";
constexpr char ASSIGN_SUB[] = "AssignSub";
constexpr char GREATER[] = "Greater";
constexpr char UNIFORM_CANDIDATE_SAMPLER[] = "UniformCandidateSampler";
constexpr char VIRTUAL_DATA_SET[] = "_VirtualDataset";
constexpr char VIRTUAL_DATA_SET_INFO[] = "VirtualDatasetInfo";
constexpr char SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS[] = "SparseSoftmaxCrossEntropyWithLogits";

View File

@@ -0,0 +1,316 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/parallel/ops_info/uniform_candidate_sampler_info.h"
#include <string>
#include <memory>
#include <vector>
#include <utility>
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/context.h"
#include "pipeline/jit/resource.h"
namespace mindspore {
namespace parallel {
Status UniformCandidateSamplerInfo::GetUniformSamplerAttrInt64(const std::string &args, int64_t *value) {
auto iter = attrs_.find(args);
if (iter == attrs_.end()) {
MS_LOG(ERROR) << name_ << ": Can not find the attr for " << args;
return FAILED;
}
MS_EXCEPTION_IF_NULL(iter->second);
if (!iter->second->isa<Int64Imm>()) {
MS_LOG(ERROR) << name_ << ": The type of attr is not int, the attr is " << args;
return FAILED;
}
*value = iter->second->cast<Int64ImmPtr>()->value();
return SUCCESS;
}
Status UniformCandidateSamplerInfo::GetUniformSamplerAttrBool(const std::string &args, bool *value) {
auto iter = attrs_.find(args);
if (iter == attrs_.end()) {
MS_LOG(ERROR) << name_ << ": Cannot find the attr " << args;
return FAILED;
}
MS_EXCEPTION_IF_NULL(iter->second);
if (!iter->second->isa<BoolImm>()) {
MS_LOG(ERROR) << name_ << ": The type of attr is not bool, the attr is " << args;
return FAILED;
}
*value = iter->second->cast<BoolImmPtr>()->value();
return SUCCESS;
}
Status UniformCandidateSamplerInfo::GetAttrs() {
if (GetUniformSamplerAttrInt64(NUM_TRUE, &num_true_) != SUCCESS ||
GetUniformSamplerAttrInt64(NUM_SAMPLED, &num_sampled_) != SUCCESS ||
GetUniformSamplerAttrBool(UNIQUE_STRING, &unique_) != SUCCESS ||
GetUniformSamplerAttrInt64(RANGE_MAX, &range_max_) != SUCCESS ||
GetUniformSamplerAttrInt64(SEED, &seed_) != SUCCESS ||
GetUniformSamplerAttrBool(REMOVE_ACCIDENTAL_HITS, &remove_accidental_hits_) != SUCCESS) {
return FAILED;
} else {
MS_LOG(INFO) << name_ << ": The num_ture is " << num_true_ << " , the num_sampled is " << num_sampled_
<< ", the unique is " << unique_ << " , the range max is " << range_max_ << " , the seed is " << seed_
<< " , the remove_accidental_hits is " << remove_accidental_hits_;
}
return SUCCESS;
}
Status UniformCandidateSamplerInfo::CheckStrategy(const StrategyPtr &strategy) {
MS_EXCEPTION_IF_NULL(strategy);
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Invalid strategy";
return FAILED;
}
std::vector<Dimensions> stra = strategy->GetInputDim();
if (stra.empty()) {
MS_LOG(ERROR) << name_ << ": The strategy is empty";
return FAILED;
}
Dimensions input_strategy = stra.at(0);
if (remove_accidental_hits_) {
bool shard = std::any_of(input_strategy.begin(), input_strategy.end(), [](int64_t v) { return v > 1; });
if (shard) {
MS_LOG(ERROR) << name_ << ": When remove accidental_hits is true, the operation only supports (1,1) shard.";
return FAILED;
}
}
return SUCCESS;
}
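The constraint above boils down to: any shard count greater than 1 is rejected once remove_accidental_hits is true. A minimal sketch, assuming the strategy is a tuple of per-dimension shard counts as in the tests below (not part of the patch):

def check_strategy(strategy, remove_accidental_hits):
    # Mirrors CheckStrategy: reject any sharded dimension when
    # remove_accidental_hits is true.
    if remove_accidental_hits and any(dim > 1 for dim in strategy):
        raise ValueError("remove_accidental_hits=True only supports the (1, 1) shard")

check_strategy((1, 1), remove_accidental_hits=True)   # accepted
check_strategy((1, 8), remove_accidental_hits=False)  # accepted
# check_strategy((1, 8), remove_accidental_hits=True) would raise, as exercised in the test file below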
Status UniformCandidateSamplerInfo::InferDevMatrixShape() {
MS_EXCEPTION_IF_NULL(strategy_);
std::vector<Dimensions> stra = strategy_->GetInputDim();
if (stra.empty()) {
MS_LOG(ERROR) << name_ << ": The strategy is empty";
return FAILED;
}
dev_matrix_shape_ = stra[0];
return SUCCESS;
}
// There are three outputs
// sampled_candidates, true_expected_count, sampled_expected_count
// sampled_candidates and sampled_expected_count are recomputed on each device with tensor map [-1];
// only true_expected_count is sharded
Status UniformCandidateSamplerInfo::InferTensorMap() {
TensorMap tensor_map;
TensorMap sampled_tensor_map = {-1};
if (inputs_shape_.empty()) {
MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
return FAILED;
}
int32_t size = SizeToInt(inputs_shape_[0].size());
for (int i = 0; i < size; ++i) {
tensor_map.push_back(size - i - 1);
}
inputs_tensor_map_.push_back(tensor_map);
// Output 1 sampled_candidates
outputs_tensor_map_.push_back(sampled_tensor_map);
// Output 2 true_expected_count
outputs_tensor_map_.push_back(tensor_map);
// Output 3 sampled_expected_count
outputs_tensor_map_.push_back(sampled_tensor_map);
return SUCCESS;
}
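As a concrete illustration of the tensor maps above, a short Python sketch (not part of the patch): for a 2-D input the map counts down from the highest device-matrix dimension, while the two locally recomputed outputs always get [-1].

def infer_tensor_maps(input_rank):
    input_map = [input_rank - i - 1 for i in range(input_rank)]  # e.g. [1, 0]
    sampled_map = [-1]  # recomputed on every device, so no sharded dimension
    # Output order: sampled_candidates, true_expected_count, sampled_expected_count.
    return input_map, [sampled_map, input_map, sampled_map]

assert infer_tensor_maps(2) == ([1, 0], [[-1], [1, 0], [-1]])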
// UniformCandidateSampler cannot be the last operator of the network
Status UniformCandidateSamplerInfo::InferAsLossDivisor() {
as_loss_divisor_ = 1;
return SUCCESS;
}
Status UniformCandidateSamplerInfo::InferMirrorOps() {
mirror_ops_.clear();
if (inputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": The inputs tensor map is empty";
return FAILED;
}
Shape input_tensor_map = inputs_tensor_map_[0];
std::vector<Group> group;
if (CreateGroupByTensorMap(input_tensor_map, &group) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Create group for input failed.";
return FAILED;
}
OperatorVector mirror_op;
if (group.empty()) {
MS_LOG(INFO) << name_ << ": The mirror group is empty.";
return SUCCESS;
} else {
mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
mirror_ops_.push_back(mirror_op);
std::string group_name = group[0].name();
MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name;
}
return SUCCESS;
}
Status UniformCandidateSamplerInfo::InferTensorInfo() {
if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": Invalid args";
return FAILED;
}
TensorLayout input_layout, output_layout;
// infer tensor layout
if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
return FAILED;
}
TensorInfo input_tensor_info(input_layout);
inputs_tensor_info_.push_back(input_tensor_info);
for (size_t i = 0; i < outputs_shape_.size(); ++i) {
// infer tensor layout
if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[i], outputs_shape_[i]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
return FAILED;
}
TensorInfo output_tensor_info(output_layout);
outputs_tensor_info_.push_back(output_tensor_info);
}
return SUCCESS;
}
Status UniformCandidateSamplerInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
return SetCostUnderStrategyBase(strategy);
}
Status UniformCandidateSamplerInfo::GenerateStrategies(int64_t stage_id) {
if (InferAttrs() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer attrs failed";
return FAILED;
}
if (inputs_shape_.empty()) {
MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
return FAILED;
}
Shape input_split = {};
Shapes splittable_input = {};
size_t splittable_value = 1;
if (remove_accidental_hits_) {
splittable_value = 0;
}
for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
input_split.push_back(splittable_value);
}
splittable_input.push_back(input_split);
std::vector<StrategyPtr> sp_vector;
if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_input, &sp_vector) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Generate strategies failed";
return FAILED;
}
size_t success = 0;
for (auto &sp : sp_vector) {
if (SetCostUnderStrategy(sp) == SUCCESS) {
success++;
MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategies.";
PrintStrategy(sp);
}
}
return SUCCESS;
}
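The splittable flags are what gate strategy generation here: with remove_accidental_hits true, every dimension is marked 0 (non-splittable), so only the all-ones strategy survives. A tiny sketch of that flag logic (hypothetical helper, not part of the patch):

def splittable_flags(input_rank, remove_accidental_hits):
    # 1 marks a dimension as splittable, 0 pins its shard count to 1.
    flag = 0 if remove_accidental_hits else 1
    return [flag] * input_rank

assert splittable_flags(2, remove_accidental_hits=False) == [1, 1]
assert splittable_flags(2, remove_accidental_hits=True) == [0, 0]  # forces (1, 1)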
std::shared_ptr<Strategys> UniformCandidateSamplerInfo::GenerateBatchStrategies() {
if (GetAttrs() != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
}
CheckGlobalDeviceManager();
Dimensions input_strategy(inputs_shape_[0].size(), 1);
Strategys strategy_v = {input_strategy};
return std::make_shared<Strategys>(strategy_v);
}
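Note that, unlike most operators, the batch strategy is all ones: under batch parallelism the sampler is replicated rather than sharded. Sketch (not part of the patch):

def generate_batch_strategy(input_rank):
    # All ones: the sampler is replicated rather than batch-sharded.
    return ((1,) * input_rank,)

assert generate_batch_strategy(2) == ((1, 1),)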
Status UniformCandidateSamplerInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}
Status UniformCandidateSamplerInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}
ReplaceGraphPtr UniformCandidateSamplerInfo::replace_graph(const CNodePtr &cnode) {
auto input_strategy = strategy_->GetInputDim().at(0);
// The attribute only needs to be modified when axis 1 is sharded
if (input_strategy.size() == 2 && input_strategy[1] > 1) {
if (ComputeReplaceGraph(cnode) != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
}
}
return replace_graph_;
}
Status UniformCandidateSamplerInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
GenerateGraph gen_g = GenerateGraph();
auto input_strategy = strategy_->GetInputDim().at(0);
if (gen_g.Init(cnode) != SUCCESS) {
MS_LOG(ERROR) << "GenerateGraph Init failed";
return FAILED;
}
auto slice_num_true = num_true_ / input_strategy[1];
// Set the attributes for the new UniformCandidateSampler; num_true is sliced by the shard count of axis 1
Attr attr_num_true = std::make_pair(NUM_TRUE, MakeValue(slice_num_true));
Attr attr_num_sampled = std::make_pair(NUM_SAMPLED, MakeValue(num_sampled_));
Attr attr_unique = std::make_pair(UNIQUE_STRING, MakeValue(unique_));
Attr attr_range_max = std::make_pair(RANGE_MAX, MakeValue(range_max_));
Attr attr_seed = std::make_pair(SEED, MakeValue(seed_));
Attr attr_remove_accidental_hits = std::make_pair(REMOVE_ACCIDENTAL_HITS, MakeValue(remove_accidental_hits_));
OperatorAttrs attrs = {attr_num_true, attr_num_sampled, attr_unique,
attr_range_max, attr_seed, attr_remove_accidental_hits};
auto new_sampler_op = gen_g.PushBack({gen_g.NewOpInst(UNIFORM_CANDIDATE_SAMPLER, attrs), gen_g.virtual_input_node()});
std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(new_sampler_op, 1)};
replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
std::make_pair(input_nodes, new_sampler_op));
return SUCCESS;
}
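The point of the replaced graph is the sliced attribute: when axis 1 is sharded, each device sees only num_true / shard true classes, so the new sampler is built with the reduced value. A worked sketch (the divisibility assert belongs to this sketch only; the C++ above performs plain integer division):

def sliced_num_true(num_true, input_strategy):
    axis1_shard = input_strategy[1]
    assert num_true % axis1_shard == 0  # sketch-only check
    return num_true // axis1_shard

assert sliced_num_true(16, (1, 8)) == 2  # matches the (1, 8) tests below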
} // namespace parallel
} // namespace mindspore

View File

@@ -0,0 +1,76 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNIFORM_CANDIDATE_SAMPLER_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNIFORM_CANDIDATE_SAMPLER_INFO_H_
#include <string>
#include <memory>
#include "ir/value.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"
namespace mindspore {
namespace parallel {
constexpr size_t UNIFORM_CANDIDATE_SAMPLER_INPUTS_SIZE = 2;
class UniformCandidateSamplerInfo : public OperatorInfo {
public:
UniformCandidateSamplerInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs,
std::make_shared<UniformCandidateSamplerCost>()),
num_sampled_(0),
num_true_(0),
unique_(false),
range_max_(0),
seed_(0),
remove_accidental_hits_(false) {}
~UniformCandidateSamplerInfo() override = default;
Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
Status GenerateStrategies(int64_t) override;
std::shared_ptr<Strategys> GenerateBatchStrategies() override;
Status SetCostUnderStrategy(const StrategyPtr &) override;
Status InferAsLossDivisor() override;
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
protected:
Status GetAttrs() override;
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferMirrorOps() override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status ComputeReplaceGraph(const CNodePtr &cnode);
private:
Status GetUniformSamplerAttrBool(const std::string &args, bool *value);
Status GetUniformSamplerAttrInt64(const std::string &args, int64_t *value);
int64_t num_sampled_;
int64_t num_true_;
bool unique_;
int64_t range_max_;
int64_t seed_;
bool remove_accidental_hits_;
};
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNIFORM_CANDIDATE_SAMPLER_INFO_H_

View File

@@ -317,7 +317,7 @@ bool IsSplittableOperator(const std::string &op_name) {
EXPM1, LOG1P, SIN, SINH, TAN, RSQRT, INV, RECIPROCAL, ROUND, FLOOR, SIGN, ERF, ERFC, ZEROSLIKE, ONESLIKE,
BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2,
SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE};
UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER};
// clang-format on
auto iter = splittable_op.find(op_name);

View File

@@ -0,0 +1,161 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.context as context
from mindspore import Tensor, Parameter
import mindspore.nn as nn
from mindspore.common.api import _executor
from mindspore.nn import TrainOneStepCell, Momentum
from mindspore.ops import operations as P
class Net(nn.Cell):
def __init__(self, embedding_weight, num_true, num_sampled, unique, range_max, seed, remove_accidental,
strategy1=None):
super(Net, self).__init__()
self.sampler = P.UniformCandidateSampler(num_true, num_sampled, unique, range_max, seed,
remove_accidental)
if strategy1:
self.sampler.shard(strategy1)
self.embedding_table = Parameter(embedding_weight, "embedding_weight")
self.gatherv2 = P.GatherV2()
self.reduce_sum = P.ReduceSum()
self.reduce_sum2 = P.ReduceSum()
self.reduce_sum3 = P.ReduceSum()
def construct(self, x):
out1, out2, out3 = self.sampler(x)
lookup = self.gatherv2(self.embedding_table, out1, 0)
loss = out1 - out3
loss = self.reduce_sum(loss, (0,))
loss2 = self.reduce_sum2(lookup, (0, 1))
loss3 = self.reduce_sum3(out2, (0, 1))
loss4 = loss + loss2 + loss3
return loss4
class Net2(nn.Cell):
def __init__(self, mul_weight, num_true, num_sampled, unique, range_max, seed, remove_accidental,
strategy1=None):
super(Net2, self).__init__()
self.sampler = P.UniformCandidateSampler(num_true, num_sampled, unique, range_max, seed,
remove_accidental)
self.cast = P.Cast()
self.weight = Parameter(mul_weight, "w1")
self.mul = P.Mul()
if strategy1:
self.sampler.shard(strategy1)
def construct(self, x):
x = self.mul(x, self.weight)
x = self.cast(x, ms.int32)
_, out2, _ = self.sampler(x)
return out2
_w = Tensor(np.ones([48, 16]), dtype=ms.float32)
_w1 = Tensor(np.ones([96, 64]), dtype=ms.float32)
_x = Tensor(np.ones([48, 16]), dtype=ms.int32)
def compile_net(net):
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_net = TrainOneStepCell(net, optimizer)
train_net.set_auto_parallel()
train_net.set_train()
_executor.compile(train_net, _x)
context.reset_auto_parallel_context()
def test_uniform_candidate_sampler_no_full_0d_split():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((4, 1),)
net = Net(_w1, num_true=16, num_sampled=16, unique=True, range_max=20, seed=1,
remove_accidental=False, strategy1=strategy1)
compile_net(net)
def test_uniform_candidate_sampler_no_full_1d_split():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 4),)
net = Net(_w1, num_true=16, num_sampled=16, unique=True, range_max=20, seed=1,
remove_accidental=False, strategy1=strategy1)
compile_net(net)
def test_uniform_candidate_sampler_full_0d_split():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((8, 1),)
net = Net(_w1, num_true=16, num_sampled=16, unique=True, range_max=20, seed=1,
remove_accidental=False, strategy1=strategy1)
compile_net(net)
def test_uniform_candidate_sampler_full_1d_split():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 8),)
net = Net(_w1, num_true=16, num_sampled=16, unique=True, range_max=20, seed=1,
remove_accidental=False, strategy1=strategy1)
compile_net(net)
def test_uniform_candidate_sampler_full_1d_unique_false():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 8),)
net = Net(_w1, num_true=16, num_sampled=16, unique=False, range_max=20, seed=1,
remove_accidental=False, strategy1=strategy1)
compile_net(net)
def test_uniform_candidate_sampler_auto_parallel():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net(_w1, num_true=16, num_sampled=16, unique=False, range_max=20, seed=1,
remove_accidental=False, strategy1=None)
compile_net(net)
def test_uniform_candidate_sampler_auto_parallel_unique_true():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net(_w1, num_true=16, num_sampled=16, unique=True, range_max=20, seed=1,
remove_accidental=False, strategy1=None)
compile_net(net)
def test_uniform_candidate_sampler_auto_parallel_remove_true():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net(_w1, num_true=16, num_sampled=16, unique=True, range_max=20, seed=1,
remove_accidental=True, strategy1=None)
compile_net(net)
def test_uniform_candidate_sampler_full_1d_remove_true():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 8),)
net = Net(_w1, num_true=16, num_sampled=16, unique=False, range_max=20, seed=1,
remove_accidental=True, strategy1=strategy1)
with pytest.raises(RuntimeError):
compile_net(net)
def test_uniform_candidate_sampler_as_final():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 8),)
net = Net2(_w, num_true=16, num_sampled=16, unique=False, range_max=20, seed=1, remove_accidental=False,
strategy1=strategy1)
with pytest.raises(RuntimeError):
compile_net(net)