Move the grad of make_range to C++ and fix some bugs
parent e7d3e467fe
commit bbe31a16e7
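
For quick context, a minimal sketch (not part of this commit's diff) of the behaviour the change targets, using the public MindSpore APIs that appear in the updated tests below; it assumes graph mode and that GradOperation is imported from mindspore.ops, which is not shown in the diff:

    # Minimal sketch, not part of this commit. Assumes graph mode and that
    # GradOperation comes from mindspore.ops (that import is not in the diff).
    from mindspore import context
    from mindspore.nn import Cell
    from mindspore.ops import GradOperation

    context.set_context(mode=context.GRAPH_MODE)


    class NetRange3(Cell):
        def construct(self, x, y, z):
            # Lowered to make_range in graph mode.
            return range(x, y, z)


    grad_net = GradOperation(get_all=True, sens_param=True)(NetRange3())
    # range(10, 100, 30) has three elements, hence the sens tuple (1, 1, 1);
    # the C++ bprop added below returns a zero for each of start, limit, delta.
    print(grad_net(10, 100, 30, (1, 1, 1)))  # expected (0, 0, 0) per the tests

The NetRange2 and NetRange1 variants in the updated tests exercise the two- and one-argument forms of range in the same way.
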
@@ -49,6 +49,7 @@
 "mindspore/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_quant_ops.cc" "internalAstError"
 "mindspore/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_scipy_ops.cc" "internalAstError"
 "mindspore/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_sparse_ops.cc" "internalAstError"
+"mindspore/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_sequence_ops.cc" "internalAstError"
 "mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_cpu_kernel.cc" "internalAstError"
 "mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/xlogy_cpu_kernel.cc" "unreadVariable"
 "mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_cpu_kernel.cc" "unreadVariable"

@@ -404,6 +404,8 @@ void RegOtherBpropExpanderOps() {
   REGISTER_EXPANDER_BPROP_IMPL(IOU);
 }
 
+void RegSequenceBpropExpanderOps() { REGISTER_EXPANDER_BPROP_IMPL(make_range); }
+
 void RegBpropExpanderOps() {
   RegMathBpropExpanderOps1();
   RegMathBpropExpanderOps2();

@@ -415,6 +417,7 @@ void RegBpropExpanderOps() {
   RegClipBpropExpanderOps();
   RegInnerBpropExpanderOps();
   RegOtherBpropExpanderOps();
+  RegSequenceBpropExpanderOps();
 }
 }  // namespace graph_bprop
 }  // namespace mindspore

@@ -0,0 +1,44 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "pipeline/pynative/grad/bprop_expander/bprop_irbuilder.h"
+#include "pipeline/pynative/grad/bprop_expander/grad_ops/common_utils.h"
+#include "include/common/utils/utils.h"
+
+namespace mindspore::expander::bprop {
+REG_BPROP_BUILDERS_BEGIN(GradSequenceOps)
+REG_BPROP_BUILDER("make_range").SetBody(BODYFUNC(ib) {
+  auto x = ib->GetInputs();
+  auto id_type = ib->GetDtypeId(ib->GetInput(kIndex0));
+  if (id_type == TypeId::kNumberTypeInt32) {
+    if (x.size() == 1) {
+      return {ib->Value(0)};
+    } else if (x.size() == 2) {
+      return {ib->Value(0), ib->Value(0)};
+    } else {
+      return {ib->Value(0), ib->Value(0), ib->Value(0)};
+    }
+  } else {
+    if (x.size() == 1) {
+      return {ib->Value<int64_t>(0)};
+    } else if (x.size() == 2) {
+      return {ib->Value<int64_t>(0), ib->Value<int64_t>(0)};
+    } else {
+      return {ib->Value<int64_t>(0), ib->Value<int64_t>(0), ib->Value<int64_t>(0)};
+    }
+  }
+});
+REG_BPROP_BUILDERS_END
+}  // namespace mindspore::expander::bprop

@@ -54,8 +54,8 @@ bool CheckMakeRangeInput(const std::vector<AbstractBasePtr> &input_args, const s
     auto element = input_args[i];
     MS_EXCEPTION_IF_NULL(element);
     auto element_type = element->BuildType();
-    if (element_type->type_id() != kInt64->type_id()) {
-      MS_EXCEPTION(TypeError) << "For '" << prim_name << "', the " << i << "th input should be a int64 scalar but got "
+    if (element_type->type_id() != kInt64->type_id() && element_type->type_id() != kInt32->type_id()) {
+      MS_EXCEPTION(TypeError) << "For '" << prim_name << "', the " << i << "th input should be a int scalar but got "
                               << element->ToString();
     }
     if (!has_variable && element->BuildValue() == kValueAny) {

@@ -65,7 +65,8 @@ bool CheckMakeRangeInput(const std::vector<AbstractBasePtr> &input_args, const s
   return has_variable;
 }
 
-abstract::AbstractTuplePtr CalcSlidePara(const std::vector<int64_t> &values, const std::string &prim_name) {
+abstract::AbstractTuplePtr CalcSlidePara(const std::vector<int64_t> &values, const std::string &prim_name,
+                                         const TypePtr &type) {
   auto values_size = values.size();
   int64_t start = values_size == 1 ? 0LL : values[kIndex0];
   int64_t stop = values_size == 1 ? values[kIndex0] : values[kIndex1];

@@ -84,7 +85,7 @@ abstract::AbstractTuplePtr CalcSlidePara(const std::vector<int64_t> &values, con
   }
 
   for (int64_t i = start; i < stop; i += step) {
-    args.push_back(std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(i)));
+    args.push_back(std::make_shared<abstract::AbstractScalar>(MakeValue(i), type));
     if (i > 0 && INT_MAX - i < step) {
       MS_EXCEPTION(ValueError) << "Integer overflow error occurred when traversing the range. "
                                << "Please check the inputs of range.";

@@ -99,7 +100,7 @@ abstract::AbstractTuplePtr CalcSlidePara(const std::vector<int64_t> &values, con
   }
 
   for (int64_t i = start; i > stop; i += step) {
-    args.push_back(std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(i)));
+    args.push_back(std::make_shared<abstract::AbstractScalar>(MakeValue(i), type));
     if (i < 0 && INT_MIN - i > step) {
       MS_EXCEPTION(ValueError) << "Integer overflow error occurred when traversing the range. "
                                << "Please check the inputs of range.";

@@ -113,9 +114,10 @@ AbstractBasePtr InferImplMakeRange(const PrimitivePtr &primitive, const Abstract
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
   bool has_variable = CheckMakeRangeInput(args_spec_list, prim_name);
+  auto type = args_spec_list[0]->BuildType();
   if (has_variable) {
     // If the input to make_range has variable input, the output abs should be dynamic length sequence.
-    auto element = std::make_shared<abstract::AbstractScalar>(kValueAny, kInt64);
+    auto element = std::make_shared<abstract::AbstractScalar>(kValueAny, type);
     auto ret = std::make_shared<abstract::AbstractTuple>(AbstractBasePtrList{element});
     ret->CheckAndConvertToDynamicLenSequence();
     return ret;

@@ -124,13 +126,13 @@ AbstractBasePtr InferImplMakeRange(const PrimitivePtr &primitive, const Abstract
   for (size_t i = 0; i < args_spec_list.size(); ++i) {
     auto element = args_spec_list[i];
     auto element_val = element->BuildValue();
-    if (!element_val->isa<Int64Imm>()) {
-      MS_EXCEPTION(TypeError) << "For '" << prim_name << "', the " << i << "th input should be a int64 scalar but got "
+    if (!element_val->isa<Int64Imm>() && !element_val->isa<Int32Imm>()) {
+      MS_EXCEPTION(TypeError) << "For '" << prim_name << "', the " << i << "th input should be a int scalar but got "
                               << element->ToString();
     }
     values.push_back(element_val->cast<Int64ImmPtr>()->value());
   }
-  return CalcSlidePara(values, prim_name);
+  return CalcSlidePara(values, prim_name, type);
 }
 }  // namespace
 

@@ -44,15 +44,6 @@ def get_bprop_sequence_len(self):
     return bprop
 
 
-@bprop_getters.register(seq.make_range)
-def get_bprop_range(self):
-    """Generate bprop for make_range"""
-    def bprop(start, limit, delta, out, dout):
-        return (zeros_like(start), zeros_like(limit), zeros_like(delta))
-
-    return bprop
-
-
 @bprop_getters.register(seq.SequenceAdd)
 def get_bprop_sequence_add(self):
     """Generate bprop for SequenceAdd"""

@@ -59,7 +59,7 @@ def test_single_for_2():
         y += x
         return y
 
-    with pytest.raises(TypeError, match="the 0th input should be a int64 scalar"):
+    with pytest.raises(TypeError, match="the 0th input should be a int scalar"):
         control_flow_for()
 
 

@@ -13,7 +13,6 @@
 # limitations under the License.
 # ============================================================================
 import pytest
-from mindspore.ops.operations import _sequence_ops as seq
 from mindspore import context
 from mindspore.nn import Cell
 from mindspore.common import mutable

@@ -24,13 +23,19 @@ context.set_context(mode=context.GRAPH_MODE)
 context_prepare()
 
 
-class Net(Cell):
-    def __init__(self):
-        super().__init__()
-        self.func = seq.make_range()
-
+class NetRange3(Cell):
     def construct(self, x, y, z):
-        return self.func(x, y, z)
+        return range(x, y, z)
 
 
+class NetRange2(Cell):
+    def construct(self, x, y):
+        return range(x, y)
+
+
+class NetRange1(Cell):
+    def construct(self, x):
+        return range(x)
+
+
 @pytest.mark.level1

@@ -45,14 +50,28 @@ def test_seqence_make_range():
     Expectation: the behavior is matched to python style
     """
 
-    def func(x, y, z):
+    def func3(x, y, z):
         return tuple(range(x, y, z))
 
-    net_ms = Net()
-    input_x = 1
+    def func2(x, y):
+        return tuple(range(x, y))
+
+    def func1(x):
+        return tuple(range(x))
+
+    input_x = 10
     input_y = 1000
     input_z = 31
-    fact = TupleFactory(net_ms, func, (input_x, input_y, input_z))
+    net_ms = NetRange3()
+    fact = TupleFactory(net_ms, func3, (input_x, input_y, input_z))
     fact.forward_cmp()
+
+    net_ms = NetRange2()
+    fact = TupleFactory(net_ms, func2, (input_x, input_y))
+    fact.forward_cmp()
+
+    net_ms = NetRange1()
+    fact = TupleFactory(net_ms, func1, (input_x,))
+    fact.forward_cmp()
 
 

@@ -67,16 +86,22 @@ def test_seqence_make_range_grad():
     Description: setitem operation on tuple type
     Expectation: the behavior is matched to python style
     """
-    net_ms = Net()
-    input_x = mutable(10)
-    input_y = mutable(100)
-    input_z = mutable(3)
-    dout = mutable((1, 1), True)
 
+    net_ms = NetRange3()
     grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
-    print("grad out1 = ", grad_func(input_x, input_y, input_z, dout))
+    input_x = 10
+    input_y = 100
+    input_z = 30
+    dout = (1, 1, 1)
+    grad_out = grad_func(input_x, input_y, input_z, dout)
+    assert grad_out == (0, 0, 0)
 
+    net_ms = NetRange2()
     grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
-    print("grad out1 = ", grad_func(input_x, input_y, input_z, dout))
+    grad_out = grad_func(input_x, input_y, dout)
+    assert grad_out == (0, 0)
 
+    net_ms = NetRange1()
+    grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
+    grad_out = grad_func(input_x, dout)
+    assert grad_out == (0,)

@@ -96,7 +96,7 @@ def test_range_with_wrong_input():
 
     with pytest.raises(TypeError) as ex:
         foo()
-    assert "input should be a int64 scalar but got" in str(ex.value)
+    assert "input should be a int scalar but got" in str(ex.value)
 
 
 def test_range_with_wrong_input_2():

@@ -115,4 +115,4 @@ def test_range_with_wrong_input_2():
 
     with pytest.raises(TypeError) as ex:
         foo()
-    assert "input should be a int64 scalar but got" in str(ex.value)
+    assert "input should be a int scalar but got" in str(ex.value)

@@ -57,7 +57,7 @@ def test_missing_return():
     z = Tensor(2, mstype.int32)
     with pytest.raises(TypeError) as er:
         net(x, y, z)
-    assert "For 'make_range', the 0th input should be a int64 scalar" in str(er.value)
+    assert "For 'make_range', the 0th input should be a int scalar" in str(er.value)
 
 
 def test_nest_function_missing_return():

@@ -89,7 +89,7 @@ def test_nest_function_missing_return():
     z = Tensor(2, mstype.int32)
     with pytest.raises(TypeError) as er:
         net(x, y, z)
-    assert "For 'make_range', the 0th input should be a int64 scalar" in str(er.value)
+    assert "For 'make_range', the 0th input should be a int scalar" in str(er.value)
 
 
 @pytest.mark.skip(reason='Case will not appear for now, but may appear in the future')