!28837 fix a bug of onehot

Merge pull request !28837 from bichaoyang/master_1
This commit is contained in:
i-robot 2022-01-17 07:32:03 +00:00 committed by Gitee
commit dcf3095302
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 2 additions and 1 deletions

View File

@ -208,7 +208,8 @@ ReplaceGraphPtr OneHotInfo::replace_graph(const CNodePtr &cnode) {
}
std::vector<StrategyPtr> OneHotInfo::GenerateOpStrategies(int64_t stage_id) {
Shapes splittable_inputs = {{1, 1}, {}, {}};
Shape input0_split(outputs_shape_[0].size(), 1);
Shapes splittable_inputs = {input0_split, {}, {}};
std::vector<StrategyPtr> sp_vector;
if (inputs_shape_.size() != 3) {
MS_LOG(EXCEPTION) << name_ << ": inputs_shape_ size must be 3, but is " << inputs_shape_.size();