Add backend check for RandomChoiceWithMask
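RandomChoiceWithMask is only supported on the GPU backend for now, so Init and InitForCostModel now call CheckGPUBackend, which raises an exception when MS_CTX_DEVICE_TARGET is not GPU. GetAttrs additionally reads the seed and seed2 attributes, and ReplaceNodeInputOrAttrs assigns a distinct, incrementing SEED_NUM to each op instance whose seeds are both 0. A UT covering compilation on a non-GPU backend is added.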

Author: liuluobin
Date: 2022-02-16 17:31:48 +08:00
parent 753aa475d9
commit b0b79ef8fb
3 changed files with 76 additions and 2 deletions

View File

@@ -19,6 +19,18 @@
namespace mindspore {
namespace parallel {
int64_t RandomChoiceWithMaskInfo::SEED_NUM = 1;

Status RandomChoiceWithMaskInfo::GetAttrs() {
  if (attrs_.find(SEED) != attrs_.end()) {
    seed_ = GetValue<int64_t>(attrs_[SEED]);
  }
  if (attrs_.find(SEED2) != attrs_.end()) {
    seed2_ = GetValue<int64_t>(attrs_[SEED2]);
  }
  return SUCCESS;
}

Status RandomChoiceWithMaskInfo::CheckStrategy(const StrategyPtr &strategy) {
  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy.";
@@ -71,5 +83,39 @@ Status RandomChoiceWithMaskInfo::InferAsLossDivisor() {
                << as_loss_divisor_;
  return SUCCESS;
}

void RandomChoiceWithMaskInfo::ReplaceNodeInputOrAttrs() {
  if (seed_ != 0 || seed2_ != 0) {
    return;
  }
  if (cnode_->HasAttr(SEED)) {
    cnode_->EraseAttr(SEED);
  }
  if (cnode_->HasAttr(SEED2)) {
    cnode_->EraseAttr(SEED2);
  }
  cnode_->AddAttr(SEED, MakeValue(SEED_NUM));
  cnode_->AddAttr(SEED2, MakeValue(SEED_NUM));
  ++SEED_NUM;
}

void RandomChoiceWithMaskInfo::CheckGPUBackend() {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (backend != kGPUDevice) {
    MS_LOG(EXCEPTION) << name_ << ": The backend is " << backend << ", only the GPU backend is supported for now.";
  }
}

Status RandomChoiceWithMaskInfo::Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) {
  CheckGPUBackend();
  return OperatorInfo::Init(in_strategy, out_strategy);
}

Status RandomChoiceWithMaskInfo::InitForCostModel(const StrategyPtr &strategy, const StrategyPtr &out_strategy) {
  CheckGPUBackend();
  return OperatorInfo::InitForCostModel(strategy, out_strategy);
}
}  // namespace parallel
}  // namespace mindspore
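
For context only (not part of this commit): a minimal sketch of the user-facing op whose default zero seeds trigger the ReplaceNodeInputOrAttrs rewrite above. It assumes the public mindspore.ops.RandomChoiceWithMask primitive; everything else here is illustrative.

import numpy as np
from mindspore import Tensor, ops

# Both seeds default to 0.  Under parallel compilation the pass above replaces
# the zero seeds with a distinct, incrementing SEED_NUM per op instance, so
# sliced instances do not implicitly share one random sequence.  (Sketch only.)
rcwm = ops.RandomChoiceWithMask(count=4, seed=0, seed2=0)
input_x = Tensor(np.array([[1, 0, 1], [0, 1, 1], [1, 1, 0]]).astype(np.bool_))
index, mask = rcwm(input_x)  # index: chosen coordinates, mask: validity flags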

View File

@@ -31,18 +31,30 @@ class RandomChoiceWithMaskInfo : public OperatorInfo {
 public:
  RandomChoiceWithMaskInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                           const PrimitiveAttrs &attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<RandomChoicWithMaskCost>()) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<RandomChoicWithMaskCost>()),
+        seed_(0),
+        seed2_(0) {}
  ~RandomChoiceWithMaskInfo() = default;

  Status Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) override;
  Status InitForCostModel(const StrategyPtr &strategy, const StrategyPtr &out_strategy) override;
  std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
  Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); }
  void ReplaceNodeInputOrAttrs() override;

 protected:
-  Status GetAttrs() override { return SUCCESS; }
+  Status GetAttrs() override;
  Status CheckStrategy(const StrategyPtr &strategy) override;
  Status InferDevMatrixShape() override;
  Status InferTensorMap() override;
  Status InferForwardCommunication() override { return SUCCESS; }
  Status InferAsLossDivisor() override;

 private:
  void CheckGPUBackend();

  int64_t seed_;
  int64_t seed2_;
  static int64_t SEED_NUM;
};
}  // namespace parallel
}  // namespace mindspore

View File

@@ -51,6 +51,7 @@ def test_auto_parallel_random_choice_with_mask():
    Description: auto parallel
    Expectation: compile success
    """
    context.set_context(device_target="GPU")
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    net = Net()
    compile_net(net, _input_x)
@@ -62,9 +63,24 @@ def test_random_choice_with_mask_wrong_strategy():
    Description: illegal strategy
    Expectation: raise RuntimeError
    """
    context.set_context(device_target="GPU")
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy = ((8, 1),)
    net = Net(strategy)
    with pytest.raises(RuntimeError):
        compile_net(net, _input_x)
    context.reset_auto_parallel_context()


def test_random_choice_with_mask_not_gpu():
    """
    Feature: RandomChoiceWithMask
    Description: compile on a non-GPU backend
    Expectation: raise RuntimeError
    """
    context.set_context(device_target="Ascend")
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    net = Net()
    with pytest.raises(RuntimeError):
        compile_net(net, _input_x)
    context.reset_auto_parallel_context()
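
The hunks above do not show the preamble of the test file, so the Net, compile_net and _input_x names they use are defined elsewhere in that file. Purely as orientation, a hypothetical minimal sketch of such helpers (not the actual definitions; the graph-executor import path is an assumption) could look like:

import numpy as np
import mindspore as ms
from mindspore import Tensor, nn, ops
from mindspore.common.api import _cell_graph_executor  # assumed import path


class Net(nn.Cell):
    # Hypothetical: wraps RandomChoiceWithMask, optionally with a shard strategy.
    def __init__(self, strategy=None):
        super().__init__()
        self.random_choice_with_mask = ops.RandomChoiceWithMask(count=256)
        if strategy is not None:
            self.random_choice_with_mask.shard(strategy)

    def construct(self, x):
        index, mask = self.random_choice_with_mask(x)
        return index, mask


_input_x = Tensor(np.ones([128, 128]), dtype=ms.bool_)


def compile_net(net, *inputs):
    # Compiles the cell in graph mode so the parallel passes (including the new GPU check) run.
    net.set_train()
    _cell_graph_executor.compile(net, *inputs)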