Grad of make_range to C++ and fix some bugs

This commit is contained in:
huangchengnuo 2023-03-27 15:50:38 +08:00
parent e7d3e467fe
commit bbe31a16e7
9 changed files with 107 additions and 41 deletions

View File

@ -49,6 +49,7 @@
"mindspore/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_quant_ops.cc" "internalAstError"
"mindspore/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_scipy_ops.cc" "internalAstError"
"mindspore/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_sparse_ops.cc" "internalAstError"
"mindspore/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_sequence_ops.cc" "internalAstError"
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_cpu_kernel.cc" "internalAstError"
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/xlogy_cpu_kernel.cc" "unreadVariable"
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_cpu_kernel.cc" "unreadVariable"

View File

@ -404,6 +404,8 @@ void RegOtherBpropExpanderOps() {
REGISTER_EXPANDER_BPROP_IMPL(IOU);
}
void RegSequenceBpropExpanderOps() { REGISTER_EXPANDER_BPROP_IMPL(make_range); }
void RegBpropExpanderOps() {
RegMathBpropExpanderOps1();
RegMathBpropExpanderOps2();
@ -415,6 +417,7 @@ void RegBpropExpanderOps() {
RegClipBpropExpanderOps();
RegInnerBpropExpanderOps();
RegOtherBpropExpanderOps();
RegSequenceBpropExpanderOps();
}
} // namespace graph_bprop
} // namespace mindspore

View File

@ -0,0 +1,44 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pipeline/pynative/grad/bprop_expander/bprop_irbuilder.h"
#include "pipeline/pynative/grad/bprop_expander/grad_ops/common_utils.h"
#include "include/common/utils/utils.h"
namespace mindspore::expander::bprop {
REG_BPROP_BUILDERS_BEGIN(GradSequenceOps)
REG_BPROP_BUILDER("make_range").SetBody(BODYFUNC(ib) {
auto x = ib->GetInputs();
auto id_type = ib->GetDtypeId(ib->GetInput(kIndex0));
if (id_type == TypeId::kNumberTypeInt32) {
if (x.size() == 1) {
return {ib->Value(0)};
} else if (x.size() == 2) {
return {ib->Value(0), ib->Value(0)};
} else {
return {ib->Value(0), ib->Value(0), ib->Value(0)};
}
} else {
if (x.size() == 1) {
return {ib->Value<int64_t>(0)};
} else if (x.size() == 2) {
return {ib->Value<int64_t>(0), ib->Value<int64_t>(0)};
} else {
return {ib->Value<int64_t>(0), ib->Value<int64_t>(0), ib->Value<int64_t>(0)};
}
}
});
REG_BPROP_BUILDERS_END
} // namespace mindspore::expander::bprop

View File

@ -54,8 +54,8 @@ bool CheckMakeRangeInput(const std::vector<AbstractBasePtr> &input_args, const s
auto element = input_args[i];
MS_EXCEPTION_IF_NULL(element);
auto element_type = element->BuildType();
if (element_type->type_id() != kInt64->type_id()) {
MS_EXCEPTION(TypeError) << "For '" << prim_name << "', the " << i << "th input should be a int64 scalar but got "
if (element_type->type_id() != kInt64->type_id() && element_type->type_id() != kInt32->type_id()) {
MS_EXCEPTION(TypeError) << "For '" << prim_name << "', the " << i << "th input should be a int scalar but got "
<< element->ToString();
}
if (!has_variable && element->BuildValue() == kValueAny) {
@ -65,7 +65,8 @@ bool CheckMakeRangeInput(const std::vector<AbstractBasePtr> &input_args, const s
return has_variable;
}
abstract::AbstractTuplePtr CalcSlidePara(const std::vector<int64_t> &values, const std::string &prim_name) {
abstract::AbstractTuplePtr CalcSlidePara(const std::vector<int64_t> &values, const std::string &prim_name,
const TypePtr &type) {
auto values_size = values.size();
int64_t start = values_size == 1 ? 0LL : values[kIndex0];
int64_t stop = values_size == 1 ? values[kIndex0] : values[kIndex1];
@ -84,7 +85,7 @@ abstract::AbstractTuplePtr CalcSlidePara(const std::vector<int64_t> &values, con
}
for (int64_t i = start; i < stop; i += step) {
args.push_back(std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(i)));
args.push_back(std::make_shared<abstract::AbstractScalar>(MakeValue(i), type));
if (i > 0 && INT_MAX - i < step) {
MS_EXCEPTION(ValueError) << "Integer overflow error occurred when traversing the range. "
<< "Please check the inputs of range.";
@ -99,7 +100,7 @@ abstract::AbstractTuplePtr CalcSlidePara(const std::vector<int64_t> &values, con
}
for (int64_t i = start; i > stop; i += step) {
args.push_back(std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(i)));
args.push_back(std::make_shared<abstract::AbstractScalar>(MakeValue(i), type));
if (i < 0 && INT_MIN - i > step) {
MS_EXCEPTION(ValueError) << "Integer overflow error occurred when traversing the range. "
<< "Please check the inputs of range.";
@ -113,9 +114,10 @@ AbstractBasePtr InferImplMakeRange(const PrimitivePtr &primitive, const Abstract
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
bool has_variable = CheckMakeRangeInput(args_spec_list, prim_name);
auto type = args_spec_list[0]->BuildType();
if (has_variable) {
// If the input to make_range has variable input, the output abs should be dynamic length sequence.
auto element = std::make_shared<abstract::AbstractScalar>(kValueAny, kInt64);
auto element = std::make_shared<abstract::AbstractScalar>(kValueAny, type);
auto ret = std::make_shared<abstract::AbstractTuple>(AbstractBasePtrList{element});
ret->CheckAndConvertToDynamicLenSequence();
return ret;
@ -124,13 +126,13 @@ AbstractBasePtr InferImplMakeRange(const PrimitivePtr &primitive, const Abstract
for (size_t i = 0; i < args_spec_list.size(); ++i) {
auto element = args_spec_list[i];
auto element_val = element->BuildValue();
if (!element_val->isa<Int64Imm>()) {
MS_EXCEPTION(TypeError) << "For '" << prim_name << "', the " << i << "th input should be a int64 scalar but got "
if (!element_val->isa<Int64Imm>() && !element_val->isa<Int32Imm>()) {
MS_EXCEPTION(TypeError) << "For '" << prim_name << "', the " << i << "th input should be a int scalar but got "
<< element->ToString();
}
values.push_back(element_val->cast<Int64ImmPtr>()->value());
}
return CalcSlidePara(values, prim_name);
return CalcSlidePara(values, prim_name, type);
}
} // namespace

View File

@ -44,15 +44,6 @@ def get_bprop_sequence_len(self):
return bprop
@bprop_getters.register(seq.make_range)
def get_bprop_range(self):
"""Generate bprop for make_range"""
def bprop(start, limit, delta, out, dout):
return (zeros_like(start), zeros_like(limit), zeros_like(delta))
return bprop
@bprop_getters.register(seq.SequenceAdd)
def get_bprop_sequence_add(self):
"""Generate bprop for SequenceAdd"""

View File

@ -59,7 +59,7 @@ def test_single_for_2():
y += x
return y
with pytest.raises(TypeError, match="the 0th input should be a int64 scalar"):
with pytest.raises(TypeError, match="the 0th input should be a int scalar"):
control_flow_for()

View File

@ -13,7 +13,6 @@
# limitations under the License.
# ============================================================================
import pytest
from mindspore.ops.operations import _sequence_ops as seq
from mindspore import context
from mindspore.nn import Cell
from mindspore.common import mutable
@ -24,13 +23,19 @@ context.set_context(mode=context.GRAPH_MODE)
context_prepare()
class Net(Cell):
def __init__(self):
super().__init__()
self.func = seq.make_range()
class NetRange3(Cell):
def construct(self, x, y, z):
return self.func(x, y, z)
return range(x, y, z)
class NetRange2(Cell):
def construct(self, x, y):
return range(x, y)
class NetRange1(Cell):
def construct(self, x):
return range(x)
@pytest.mark.level1
@ -45,14 +50,28 @@ def test_seqence_make_range():
Expectation: the behavior is matched to python style
"""
def func(x, y, z):
def func3(x, y, z):
return tuple(range(x, y, z))
net_ms = Net()
input_x = 1
def func2(x, y):
return tuple(range(x, y))
def func1(x):
return tuple(range(x))
input_x = 10
input_y = 1000
input_z = 31
fact = TupleFactory(net_ms, func, (input_x, input_y, input_z))
net_ms = NetRange3()
fact = TupleFactory(net_ms, func3, (input_x, input_y, input_z))
fact.forward_cmp()
net_ms = NetRange2()
fact = TupleFactory(net_ms, func2, (input_x, input_y))
fact.forward_cmp()
net_ms = NetRange1()
fact = TupleFactory(net_ms, func1, (input_x,))
fact.forward_cmp()
@ -67,16 +86,22 @@ def test_seqence_make_range_grad():
Description: setitem operation on tuple type
Expectation: the behavior is matched to python style
"""
net_ms = Net()
input_x = mutable(10)
input_y = mutable(100)
input_z = mutable(3)
dout = mutable((1, 1), True)
net_ms = NetRange3()
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
print("grad out1 = ", grad_func(input_x, input_y, input_z, dout))
input_x = 10
input_y = 100
input_z = 30
dout = (1, 1, 1)
grad_out = grad_func(input_x, input_y, input_z, dout)
assert grad_out == (0, 0, 0)
net_ms = NetRange2()
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
print("grad out1 = ", grad_func(input_x, input_y, input_z, dout))
grad_out = grad_func(input_x, input_y, dout)
assert grad_out == (0, 0)
net_ms = NetRange1()
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
grad_out = grad_func(input_x, dout)
assert grad_out == (0,)

View File

@ -96,7 +96,7 @@ def test_range_with_wrong_input():
with pytest.raises(TypeError) as ex:
foo()
assert "input should be a int64 scalar but got" in str(ex.value)
assert "input should be a int scalar but got" in str(ex.value)
def test_range_with_wrong_input_2():
@ -115,4 +115,4 @@ def test_range_with_wrong_input_2():
with pytest.raises(TypeError) as ex:
foo()
assert "input should be a int64 scalar but got" in str(ex.value)
assert "input should be a int scalar but got" in str(ex.value)

View File

@ -57,7 +57,7 @@ def test_missing_return():
z = Tensor(2, mstype.int32)
with pytest.raises(TypeError) as er:
net(x, y, z)
assert "For 'make_range', the 0th input should be a int64 scalar" in str(er.value)
assert "For 'make_range', the 0th input should be a int scalar" in str(er.value)
def test_nest_function_missing_return():
@ -89,7 +89,7 @@ def test_nest_function_missing_return():
z = Tensor(2, mstype.int32)
with pytest.raises(TypeError) as er:
net(x, y, z)
assert "For 'make_range', the 0th input should be a int64 scalar" in str(er.value)
assert "For 'make_range', the 0th input should be a int scalar" in str(er.value)
@pytest.mark.skip(reason='Case will not appear for now, but may appear in the future')