forked from mindspore-Ecosystem/mindspore
!12543 add StridedSliceGrad op
From: @lyvette Reviewed-by: Signed-off-by:
This commit is contained in:
commit
eeb7291d51
|
@ -0,0 +1,84 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/grad/strided_slice_grad.h"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <bitset>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
void StridedSliceGrad::Init(const int64_t begin_mask, const int64_t end_mask, const int64_t ellipsis_mask,
|
||||
const int64_t new_axis_mask, const int64_t shrink_axis_mask) {
|
||||
this->set_begin_mask(begin_mask);
|
||||
this->set_end_mask(end_mask);
|
||||
this->set_ellipsis_mask(ellipsis_mask);
|
||||
this->set_new_axis_mask(new_axis_mask);
|
||||
this->set_shrink_axis_mask(shrink_axis_mask);
|
||||
}
|
||||
|
||||
void StridedSliceGrad::set_begin_mask(const int64_t begin_mask) {
|
||||
CheckAndConvertUtils::CheckInteger(kBeginMask, begin_mask, kGreaterEqual, 0, this->name());
|
||||
this->AddAttr(kBeginMask, MakeValue(begin_mask));
|
||||
}
|
||||
int64_t StridedSliceGrad::get_begin_mask() const {
|
||||
auto value_ptr = GetAttr(kBeginMask);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
void StridedSliceGrad::set_end_mask(const int64_t end_mask) {
|
||||
CheckAndConvertUtils::CheckInteger(kEndMask, end_mask, kGreaterEqual, 0, this->name());
|
||||
this->AddAttr(kEndMask, MakeValue(end_mask));
|
||||
}
|
||||
int64_t StridedSliceGrad::get_end_mask() const {
|
||||
auto value_ptr = GetAttr(kEndMask);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
void StridedSliceGrad::set_ellipsis_mask(const int64_t ellipsis_mask) {
|
||||
CheckAndConvertUtils::CheckInteger(kEllipsisMask, ellipsis_mask, kGreaterEqual, 0, this->name());
|
||||
std::bitset<sizeof(int64_t) * 8> bs(ellipsis_mask);
|
||||
std::ostringstream buffer;
|
||||
if (bs.count() > 1) {
|
||||
buffer << "For" << this->name() << ", only support one ellipsis in the index, but got " << this->get_end_mask();
|
||||
MS_EXCEPTION(ValueError) << buffer.str();
|
||||
}
|
||||
this->AddAttr(kEllipsisMask, MakeValue(ellipsis_mask));
|
||||
}
|
||||
int64_t StridedSliceGrad::get_ellipsis_mask() const {
|
||||
auto value_ptr = GetAttr(kEllipsisMask);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
void StridedSliceGrad::set_new_axis_mask(const int64_t new_axis_mask) {
|
||||
CheckAndConvertUtils::CheckInteger(kNewAxisMask, new_axis_mask, kGreaterEqual, 0, this->name());
|
||||
this->AddAttr(kNewAxisMask, MakeValue(new_axis_mask));
|
||||
}
|
||||
int64_t StridedSliceGrad::get_new_axis_mask() const {
|
||||
auto value_ptr = GetAttr(kNewAxisMask);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
void StridedSliceGrad::set_shrink_axis_mask(const int64_t shrink_axis_mask) {
|
||||
CheckAndConvertUtils::CheckInteger(kShrinkAxisMask, shrink_axis_mask, kGreaterEqual, 0, this->name());
|
||||
this->AddAttr(kShrinkAxisMask, MakeValue(shrink_axis_mask));
|
||||
}
|
||||
int64_t StridedSliceGrad::get_shrink_axis_mask() const {
|
||||
auto value_ptr = GetAttr(kShrinkAxisMask);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_C(kNameStridedSliceGrad, StridedSliceGrad);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,52 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_STRIDED_SLICE_GRAD_H_
|
||||
#define MINDSPORE_CORE_OPS_STRIDED_SLICE_GRAD_H_
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameStridedSliceGrad = "StridedSliceGrad";
|
||||
class StridedSliceGrad : public PrimitiveC {
|
||||
public:
|
||||
StridedSliceGrad() : PrimitiveC(kNameStridedSliceGrad) {}
|
||||
~StridedSliceGrad() = default;
|
||||
MS_DECLARE_PARENT(StridedSliceGrad, PrimitiveC);
|
||||
void Init(const int64_t begin_mask = 0, const int64_t end_mask = 0, const int64_t ellipsis_mask = 0,
|
||||
const int64_t new_axis_mask = 0, const int64_t shrink_axis_mask = 0);
|
||||
void set_begin_mask(const int64_t begin_mask);
|
||||
void set_end_mask(const int64_t end_mask);
|
||||
void set_ellipsis_mask(const int64_t ellipsis_mask);
|
||||
void set_new_axis_mask(const int64_t new_axis_mask);
|
||||
void set_shrink_axis_mask(const int64_t shrink_axis_mask);
|
||||
int64_t get_begin_mask() const;
|
||||
int64_t get_end_mask() const;
|
||||
int64_t get_ellipsis_mask() const;
|
||||
int64_t get_new_axis_mask() const;
|
||||
int64_t get_shrink_axis_mask() const;
|
||||
};
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_STRIDED_SLICE_GRAD_H_
|
Loading…
Reference in New Issue