Grad of make_range to C++ and fix some bugs
This commit is contained in:
parent
e7d3e467fe
commit
bbe31a16e7
|
@ -49,6 +49,7 @@
|
||||||
"mindspore/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_quant_ops.cc" "internalAstError"
|
"mindspore/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_quant_ops.cc" "internalAstError"
|
||||||
"mindspore/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_scipy_ops.cc" "internalAstError"
|
"mindspore/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_scipy_ops.cc" "internalAstError"
|
||||||
"mindspore/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_sparse_ops.cc" "internalAstError"
|
"mindspore/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_sparse_ops.cc" "internalAstError"
|
||||||
|
"mindspore/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_sequence_ops.cc" "internalAstError"
|
||||||
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_cpu_kernel.cc" "internalAstError"
|
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_cpu_kernel.cc" "internalAstError"
|
||||||
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/xlogy_cpu_kernel.cc" "unreadVariable"
|
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/xlogy_cpu_kernel.cc" "unreadVariable"
|
||||||
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_cpu_kernel.cc" "unreadVariable"
|
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_cpu_kernel.cc" "unreadVariable"
|
||||||
|
|
|
@ -404,6 +404,8 @@ void RegOtherBpropExpanderOps() {
|
||||||
REGISTER_EXPANDER_BPROP_IMPL(IOU);
|
REGISTER_EXPANDER_BPROP_IMPL(IOU);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void RegSequenceBpropExpanderOps() { REGISTER_EXPANDER_BPROP_IMPL(make_range); }
|
||||||
|
|
||||||
void RegBpropExpanderOps() {
|
void RegBpropExpanderOps() {
|
||||||
RegMathBpropExpanderOps1();
|
RegMathBpropExpanderOps1();
|
||||||
RegMathBpropExpanderOps2();
|
RegMathBpropExpanderOps2();
|
||||||
|
@ -415,6 +417,7 @@ void RegBpropExpanderOps() {
|
||||||
RegClipBpropExpanderOps();
|
RegClipBpropExpanderOps();
|
||||||
RegInnerBpropExpanderOps();
|
RegInnerBpropExpanderOps();
|
||||||
RegOtherBpropExpanderOps();
|
RegOtherBpropExpanderOps();
|
||||||
|
RegSequenceBpropExpanderOps();
|
||||||
}
|
}
|
||||||
} // namespace graph_bprop
|
} // namespace graph_bprop
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -0,0 +1,44 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "pipeline/pynative/grad/bprop_expander/bprop_irbuilder.h"
|
||||||
|
#include "pipeline/pynative/grad/bprop_expander/grad_ops/common_utils.h"
|
||||||
|
#include "include/common/utils/utils.h"
|
||||||
|
|
||||||
|
namespace mindspore::expander::bprop {
|
||||||
|
REG_BPROP_BUILDERS_BEGIN(GradSequenceOps)
|
||||||
|
REG_BPROP_BUILDER("make_range").SetBody(BODYFUNC(ib) {
|
||||||
|
auto x = ib->GetInputs();
|
||||||
|
auto id_type = ib->GetDtypeId(ib->GetInput(kIndex0));
|
||||||
|
if (id_type == TypeId::kNumberTypeInt32) {
|
||||||
|
if (x.size() == 1) {
|
||||||
|
return {ib->Value(0)};
|
||||||
|
} else if (x.size() == 2) {
|
||||||
|
return {ib->Value(0), ib->Value(0)};
|
||||||
|
} else {
|
||||||
|
return {ib->Value(0), ib->Value(0), ib->Value(0)};
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (x.size() == 1) {
|
||||||
|
return {ib->Value<int64_t>(0)};
|
||||||
|
} else if (x.size() == 2) {
|
||||||
|
return {ib->Value<int64_t>(0), ib->Value<int64_t>(0)};
|
||||||
|
} else {
|
||||||
|
return {ib->Value<int64_t>(0), ib->Value<int64_t>(0), ib->Value<int64_t>(0)};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
REG_BPROP_BUILDERS_END
|
||||||
|
} // namespace mindspore::expander::bprop
|
|
@ -54,8 +54,8 @@ bool CheckMakeRangeInput(const std::vector<AbstractBasePtr> &input_args, const s
|
||||||
auto element = input_args[i];
|
auto element = input_args[i];
|
||||||
MS_EXCEPTION_IF_NULL(element);
|
MS_EXCEPTION_IF_NULL(element);
|
||||||
auto element_type = element->BuildType();
|
auto element_type = element->BuildType();
|
||||||
if (element_type->type_id() != kInt64->type_id()) {
|
if (element_type->type_id() != kInt64->type_id() && element_type->type_id() != kInt32->type_id()) {
|
||||||
MS_EXCEPTION(TypeError) << "For '" << prim_name << "', the " << i << "th input should be a int64 scalar but got "
|
MS_EXCEPTION(TypeError) << "For '" << prim_name << "', the " << i << "th input should be a int scalar but got "
|
||||||
<< element->ToString();
|
<< element->ToString();
|
||||||
}
|
}
|
||||||
if (!has_variable && element->BuildValue() == kValueAny) {
|
if (!has_variable && element->BuildValue() == kValueAny) {
|
||||||
|
@ -65,7 +65,8 @@ bool CheckMakeRangeInput(const std::vector<AbstractBasePtr> &input_args, const s
|
||||||
return has_variable;
|
return has_variable;
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract::AbstractTuplePtr CalcSlidePara(const std::vector<int64_t> &values, const std::string &prim_name) {
|
abstract::AbstractTuplePtr CalcSlidePara(const std::vector<int64_t> &values, const std::string &prim_name,
|
||||||
|
const TypePtr &type) {
|
||||||
auto values_size = values.size();
|
auto values_size = values.size();
|
||||||
int64_t start = values_size == 1 ? 0LL : values[kIndex0];
|
int64_t start = values_size == 1 ? 0LL : values[kIndex0];
|
||||||
int64_t stop = values_size == 1 ? values[kIndex0] : values[kIndex1];
|
int64_t stop = values_size == 1 ? values[kIndex0] : values[kIndex1];
|
||||||
|
@ -84,7 +85,7 @@ abstract::AbstractTuplePtr CalcSlidePara(const std::vector<int64_t> &values, con
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int64_t i = start; i < stop; i += step) {
|
for (int64_t i = start; i < stop; i += step) {
|
||||||
args.push_back(std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(i)));
|
args.push_back(std::make_shared<abstract::AbstractScalar>(MakeValue(i), type));
|
||||||
if (i > 0 && INT_MAX - i < step) {
|
if (i > 0 && INT_MAX - i < step) {
|
||||||
MS_EXCEPTION(ValueError) << "Integer overflow error occurred when traversing the range. "
|
MS_EXCEPTION(ValueError) << "Integer overflow error occurred when traversing the range. "
|
||||||
<< "Please check the inputs of range.";
|
<< "Please check the inputs of range.";
|
||||||
|
@ -99,7 +100,7 @@ abstract::AbstractTuplePtr CalcSlidePara(const std::vector<int64_t> &values, con
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int64_t i = start; i > stop; i += step) {
|
for (int64_t i = start; i > stop; i += step) {
|
||||||
args.push_back(std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(i)));
|
args.push_back(std::make_shared<abstract::AbstractScalar>(MakeValue(i), type));
|
||||||
if (i < 0 && INT_MIN - i > step) {
|
if (i < 0 && INT_MIN - i > step) {
|
||||||
MS_EXCEPTION(ValueError) << "Integer overflow error occurred when traversing the range. "
|
MS_EXCEPTION(ValueError) << "Integer overflow error occurred when traversing the range. "
|
||||||
<< "Please check the inputs of range.";
|
<< "Please check the inputs of range.";
|
||||||
|
@ -113,9 +114,10 @@ AbstractBasePtr InferImplMakeRange(const PrimitivePtr &primitive, const Abstract
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
auto prim_name = primitive->name();
|
auto prim_name = primitive->name();
|
||||||
bool has_variable = CheckMakeRangeInput(args_spec_list, prim_name);
|
bool has_variable = CheckMakeRangeInput(args_spec_list, prim_name);
|
||||||
|
auto type = args_spec_list[0]->BuildType();
|
||||||
if (has_variable) {
|
if (has_variable) {
|
||||||
// If the input to make_range has variable input, the output abs should be dynamic length sequence.
|
// If the input to make_range has variable input, the output abs should be dynamic length sequence.
|
||||||
auto element = std::make_shared<abstract::AbstractScalar>(kValueAny, kInt64);
|
auto element = std::make_shared<abstract::AbstractScalar>(kValueAny, type);
|
||||||
auto ret = std::make_shared<abstract::AbstractTuple>(AbstractBasePtrList{element});
|
auto ret = std::make_shared<abstract::AbstractTuple>(AbstractBasePtrList{element});
|
||||||
ret->CheckAndConvertToDynamicLenSequence();
|
ret->CheckAndConvertToDynamicLenSequence();
|
||||||
return ret;
|
return ret;
|
||||||
|
@ -124,13 +126,13 @@ AbstractBasePtr InferImplMakeRange(const PrimitivePtr &primitive, const Abstract
|
||||||
for (size_t i = 0; i < args_spec_list.size(); ++i) {
|
for (size_t i = 0; i < args_spec_list.size(); ++i) {
|
||||||
auto element = args_spec_list[i];
|
auto element = args_spec_list[i];
|
||||||
auto element_val = element->BuildValue();
|
auto element_val = element->BuildValue();
|
||||||
if (!element_val->isa<Int64Imm>()) {
|
if (!element_val->isa<Int64Imm>() && !element_val->isa<Int32Imm>()) {
|
||||||
MS_EXCEPTION(TypeError) << "For '" << prim_name << "', the " << i << "th input should be a int64 scalar but got "
|
MS_EXCEPTION(TypeError) << "For '" << prim_name << "', the " << i << "th input should be a int scalar but got "
|
||||||
<< element->ToString();
|
<< element->ToString();
|
||||||
}
|
}
|
||||||
values.push_back(element_val->cast<Int64ImmPtr>()->value());
|
values.push_back(element_val->cast<Int64ImmPtr>()->value());
|
||||||
}
|
}
|
||||||
return CalcSlidePara(values, prim_name);
|
return CalcSlidePara(values, prim_name, type);
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
|
|
@ -44,15 +44,6 @@ def get_bprop_sequence_len(self):
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
@bprop_getters.register(seq.make_range)
|
|
||||||
def get_bprop_range(self):
|
|
||||||
"""Generate bprop for make_range"""
|
|
||||||
def bprop(start, limit, delta, out, dout):
|
|
||||||
return (zeros_like(start), zeros_like(limit), zeros_like(delta))
|
|
||||||
|
|
||||||
return bprop
|
|
||||||
|
|
||||||
|
|
||||||
@bprop_getters.register(seq.SequenceAdd)
|
@bprop_getters.register(seq.SequenceAdd)
|
||||||
def get_bprop_sequence_add(self):
|
def get_bprop_sequence_add(self):
|
||||||
"""Generate bprop for SequenceAdd"""
|
"""Generate bprop for SequenceAdd"""
|
||||||
|
|
|
@ -59,7 +59,7 @@ def test_single_for_2():
|
||||||
y += x
|
y += x
|
||||||
return y
|
return y
|
||||||
|
|
||||||
with pytest.raises(TypeError, match="the 0th input should be a int64 scalar"):
|
with pytest.raises(TypeError, match="the 0th input should be a int scalar"):
|
||||||
control_flow_for()
|
control_flow_for()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,6 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
import pytest
|
import pytest
|
||||||
from mindspore.ops.operations import _sequence_ops as seq
|
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore.nn import Cell
|
from mindspore.nn import Cell
|
||||||
from mindspore.common import mutable
|
from mindspore.common import mutable
|
||||||
|
@ -24,13 +23,19 @@ context.set_context(mode=context.GRAPH_MODE)
|
||||||
context_prepare()
|
context_prepare()
|
||||||
|
|
||||||
|
|
||||||
class Net(Cell):
|
class NetRange3(Cell):
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.func = seq.make_range()
|
|
||||||
|
|
||||||
def construct(self, x, y, z):
|
def construct(self, x, y, z):
|
||||||
return self.func(x, y, z)
|
return range(x, y, z)
|
||||||
|
|
||||||
|
|
||||||
|
class NetRange2(Cell):
|
||||||
|
def construct(self, x, y):
|
||||||
|
return range(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
class NetRange1(Cell):
|
||||||
|
def construct(self, x):
|
||||||
|
return range(x)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level1
|
@pytest.mark.level1
|
||||||
|
@ -45,14 +50,28 @@ def test_seqence_make_range():
|
||||||
Expectation: the behavior is matched to python style
|
Expectation: the behavior is matched to python style
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def func(x, y, z):
|
def func3(x, y, z):
|
||||||
return tuple(range(x, y, z))
|
return tuple(range(x, y, z))
|
||||||
|
|
||||||
net_ms = Net()
|
def func2(x, y):
|
||||||
input_x = 1
|
return tuple(range(x, y))
|
||||||
|
|
||||||
|
def func1(x):
|
||||||
|
return tuple(range(x))
|
||||||
|
|
||||||
|
input_x = 10
|
||||||
input_y = 1000
|
input_y = 1000
|
||||||
input_z = 31
|
input_z = 31
|
||||||
fact = TupleFactory(net_ms, func, (input_x, input_y, input_z))
|
net_ms = NetRange3()
|
||||||
|
fact = TupleFactory(net_ms, func3, (input_x, input_y, input_z))
|
||||||
|
fact.forward_cmp()
|
||||||
|
|
||||||
|
net_ms = NetRange2()
|
||||||
|
fact = TupleFactory(net_ms, func2, (input_x, input_y))
|
||||||
|
fact.forward_cmp()
|
||||||
|
|
||||||
|
net_ms = NetRange1()
|
||||||
|
fact = TupleFactory(net_ms, func1, (input_x,))
|
||||||
fact.forward_cmp()
|
fact.forward_cmp()
|
||||||
|
|
||||||
|
|
||||||
|
@ -67,16 +86,22 @@ def test_seqence_make_range_grad():
|
||||||
Description: setitem operation on tuple type
|
Description: setitem operation on tuple type
|
||||||
Expectation: the behavior is matched to python style
|
Expectation: the behavior is matched to python style
|
||||||
"""
|
"""
|
||||||
net_ms = Net()
|
|
||||||
input_x = mutable(10)
|
input_x = mutable(10)
|
||||||
input_y = mutable(100)
|
input_y = mutable(100)
|
||||||
input_z = mutable(3)
|
input_z = mutable(3)
|
||||||
dout = mutable((1, 1), True)
|
dout = mutable((1, 1), True)
|
||||||
|
|
||||||
|
net_ms = NetRange3()
|
||||||
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
|
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
|
||||||
print("grad out1 = ", grad_func(input_x, input_y, input_z, dout))
|
grad_out = grad_func(input_x, input_y, input_z, dout)
|
||||||
input_x = 10
|
assert grad_out == (0, 0, 0)
|
||||||
input_y = 100
|
|
||||||
input_z = 30
|
net_ms = NetRange2()
|
||||||
dout = (1, 1, 1)
|
|
||||||
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
|
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
|
||||||
print("grad out1 = ", grad_func(input_x, input_y, input_z, dout))
|
grad_out = grad_func(input_x, input_y, dout)
|
||||||
|
assert grad_out == (0, 0)
|
||||||
|
|
||||||
|
net_ms = NetRange1()
|
||||||
|
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
|
||||||
|
grad_out = grad_func(input_x, dout)
|
||||||
|
assert grad_out == (0,)
|
||||||
|
|
|
@ -96,7 +96,7 @@ def test_range_with_wrong_input():
|
||||||
|
|
||||||
with pytest.raises(TypeError) as ex:
|
with pytest.raises(TypeError) as ex:
|
||||||
foo()
|
foo()
|
||||||
assert "input should be a int64 scalar but got" in str(ex.value)
|
assert "input should be a int scalar but got" in str(ex.value)
|
||||||
|
|
||||||
|
|
||||||
def test_range_with_wrong_input_2():
|
def test_range_with_wrong_input_2():
|
||||||
|
@ -115,4 +115,4 @@ def test_range_with_wrong_input_2():
|
||||||
|
|
||||||
with pytest.raises(TypeError) as ex:
|
with pytest.raises(TypeError) as ex:
|
||||||
foo()
|
foo()
|
||||||
assert "input should be a int64 scalar but got" in str(ex.value)
|
assert "input should be a int scalar but got" in str(ex.value)
|
||||||
|
|
|
@ -57,7 +57,7 @@ def test_missing_return():
|
||||||
z = Tensor(2, mstype.int32)
|
z = Tensor(2, mstype.int32)
|
||||||
with pytest.raises(TypeError) as er:
|
with pytest.raises(TypeError) as er:
|
||||||
net(x, y, z)
|
net(x, y, z)
|
||||||
assert "For 'make_range', the 0th input should be a int64 scalar" in str(er.value)
|
assert "For 'make_range', the 0th input should be a int scalar" in str(er.value)
|
||||||
|
|
||||||
|
|
||||||
def test_nest_function_missing_return():
|
def test_nest_function_missing_return():
|
||||||
|
@ -89,7 +89,7 @@ def test_nest_function_missing_return():
|
||||||
z = Tensor(2, mstype.int32)
|
z = Tensor(2, mstype.int32)
|
||||||
with pytest.raises(TypeError) as er:
|
with pytest.raises(TypeError) as er:
|
||||||
net(x, y, z)
|
net(x, y, z)
|
||||||
assert "For 'make_range', the 0th input should be a int64 scalar" in str(er.value)
|
assert "For 'make_range', the 0th input should be a int scalar" in str(er.value)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason='Case will not appear for now, but may appear in the future')
|
@pytest.mark.skip(reason='Case will not appear for now, but may appear in the future')
|
||||||
|
|
Loading…
Reference in New Issue