forked from mindspore-Ecosystem/mindspore
Add backend check for RandomChoiceWithMask
This commit is contained in:
parent
753aa475d9
commit
b0b79ef8fb
|
@ -19,6 +19,18 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
int64_t RandomChoiceWithMaskInfo::SEED_NUM = 1;
|
||||
|
||||
Status RandomChoiceWithMaskInfo::GetAttrs() {
|
||||
if (attrs_.find(SEED) != attrs_.end()) {
|
||||
seed_ = GetValue<int64_t>(attrs_[SEED]);
|
||||
}
|
||||
if (attrs_.find(SEED2) != attrs_.end()) {
|
||||
seed2_ = GetValue<int64_t>(attrs_[SEED2]);
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status RandomChoiceWithMaskInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy.";
|
||||
|
@ -71,5 +83,39 @@ Status RandomChoiceWithMaskInfo::InferAsLossDivisor() {
|
|||
<< as_loss_divisor_;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
void RandomChoiceWithMaskInfo::ReplaceNodeInputOrAttrs() {
|
||||
if (seed_ != 0 || seed2_ != 0) {
|
||||
return;
|
||||
}
|
||||
if (cnode_->HasAttr(SEED)) {
|
||||
cnode_->EraseAttr(SEED);
|
||||
}
|
||||
if (cnode_->HasAttr(SEED2)) {
|
||||
cnode_->EraseAttr(SEED2);
|
||||
}
|
||||
cnode_->AddAttr(SEED, MakeValue(SEED_NUM));
|
||||
cnode_->AddAttr(SEED2, MakeValue(SEED_NUM));
|
||||
++SEED_NUM;
|
||||
}
|
||||
|
||||
void RandomChoiceWithMaskInfo::CheckGPUBackend() {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
if (backend != kGPUDevice) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": The backend is " << backend << " , only support on GPU backend now.";
|
||||
}
|
||||
}
|
||||
|
||||
Status RandomChoiceWithMaskInfo::Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) {
|
||||
CheckGPUBackend();
|
||||
return OperatorInfo::Init(in_strategy, out_strategy);
|
||||
}
|
||||
|
||||
Status RandomChoiceWithMaskInfo::InitForCostModel(const StrategyPtr &strategy, const StrategyPtr &out_strategy) {
|
||||
CheckGPUBackend();
|
||||
return OperatorInfo::InitForCostModel(strategy, out_strategy);
|
||||
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -31,18 +31,30 @@ class RandomChoiceWithMaskInfo : public OperatorInfo {
|
|||
public:
|
||||
RandomChoiceWithMaskInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<RandomChoicWithMaskCost>()) {}
|
||||
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<RandomChoicWithMaskCost>()),
|
||||
seed_(0),
|
||||
seed2_(0) {}
|
||||
~RandomChoiceWithMaskInfo() = default;
|
||||
Status Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) override;
|
||||
Status InitForCostModel(const StrategyPtr &strategy, const StrategyPtr &out_strategy) override;
|
||||
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
|
||||
Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); }
|
||||
void ReplaceNodeInputOrAttrs() override;
|
||||
|
||||
protected:
|
||||
Status GetAttrs() override { return SUCCESS; }
|
||||
Status GetAttrs() override;
|
||||
Status CheckStrategy(const StrategyPtr &strategy) override;
|
||||
Status InferDevMatrixShape() override;
|
||||
Status InferTensorMap() override;
|
||||
Status InferForwardCommunication() override { return SUCCESS; }
|
||||
Status InferAsLossDivisor() override;
|
||||
|
||||
private:
|
||||
void CheckGPUBackend();
|
||||
|
||||
int64_t seed_;
|
||||
int64_t seed2_;
|
||||
static int64_t SEED_NUM;
|
||||
};
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -51,6 +51,7 @@ def test_auto_parallel_random_choice_with_mask():
|
|||
Description: auto parallel
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_context(device_target="GPU")
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
||||
net = Net()
|
||||
compile_net(net, _input_x)
|
||||
|
@ -62,9 +63,24 @@ def test_random_choice_with_mask_wrong_strategy():
|
|||
Description: illegal strategy
|
||||
Expectation: raise RuntimeError
|
||||
"""
|
||||
context.set_context(device_target="GPU")
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy = ((8, 1),)
|
||||
net = Net(strategy)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net, _input_x)
|
||||
context.reset_auto_parallel_context()
|
||||
|
||||
|
||||
def test_random_choice_with_mask_not_gpu():
|
||||
"""
|
||||
Feature: RandomChoiceWithMask
|
||||
Description: not compile with gpu backend
|
||||
Expectation: raise RuntimeError
|
||||
"""
|
||||
context.set_context(device_target="Ascend")
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
||||
net = Net()
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net, _input_x)
|
||||
context.reset_auto_parallel_context()
|
||||
|
|
Loading…
Reference in New Issue