forked from mindspore-Ecosystem/mindspore
add dynamic broadcastgradientargs
This commit is contained in:
parent
8361f34d78
commit
f1b59bc454
@@ -0,0 +1,206 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/host/dynamic_broadcast_gradient_args_kernel.h"
#include "backend/session/anf_runtime_algorithm.h"

namespace mindspore {
namespace kernel {
namespace {
const int kInputNum = 2;

std::vector<std::vector<int64_t>> GetGradientIndices(const std::vector<std::vector<int64_t>> &reverse_shape,
                                                     const size_t largest_rank) {
  std::vector<std::vector<int64_t>> grad_reduce_idx(kInputNum);
  // Indices of the j-th component of each input.
  bool prev_is_one[kInputNum];
  bool current_is_one[kInputNum];
  for (int i = 0; i < kInputNum; ++i) {
    prev_is_one[i] = false;
    current_is_one[i] = false;
  }

  bool set_one = false;
  for (size_t j = 0; j < largest_rank; ++j) {
    int output_dim = -1;
    bool output_dim_set = false;
    bool none_is_one = true;
    // Find which indices are 1.
    for (int i = 0; i < kInputNum; ++i) {
      if (reverse_shape[i][j] == 1) {
        current_is_one[i] = true;
        none_is_one = false;
      } else {
        current_is_one[i] = false;
        if (!output_dim_set || reverse_shape[i][j] == static_cast<int64_t>(output_dim)) {
          output_dim = reverse_shape[i][j];
          output_dim_set = true;
        } else {
          MS_LOG(EXCEPTION) << "Input[0] and input[1] cannot broadcast!";
        }
      }
    }

    // All dimensions are 1.
    if (!output_dim_set) {
      for (int i = 0; i < kInputNum; ++i) {
        grad_reduce_idx[i].push_back(largest_rank - 1 - j);
      }
      continue;
    } else if (std::equal(current_is_one, current_is_one + kInputNum, prev_is_one) && set_one) {
      for (int i = 0; i < kInputNum; ++i) {
        if (current_is_one[i] && !none_is_one) {
          grad_reduce_idx[i].push_back(largest_rank - 1 - j);
        }
      }
    } else {
      for (int i = 0; i < kInputNum; ++i) {
        if (current_is_one[i] && !none_is_one) {
          grad_reduce_idx[i].push_back(largest_rank - 1 - j);
        }
      }
    }
    set_one = true;
    for (int i = 0; i < kInputNum; ++i) {
      prev_is_one[i] = current_is_one[i];
    }
  }
  return grad_reduce_idx;
}

std::vector<std::vector<int64_t>> CalculateOutput(const std::vector<std::vector<int64_t>> &x) {
  std::vector<std::vector<int64_t>> grad_reduce_idx(kInputNum);
  bool all_equal = true;
  size_t largest_rank = 0;
  for (int i = 0; i < kInputNum; ++i) {
    if (x[i] != x[0]) {
      all_equal = false;
    }
    if (x[i].size() > largest_rank) {
      largest_rank = x[i].size();
    }
  }
  if (all_equal) {
    return grad_reduce_idx;
  }

  // Reverse the input shapes.
  std::vector<std::vector<int64_t>> reverse_shape(kInputNum);
  for (int i = 0; i < kInputNum; ++i) {
    reverse_shape[i] = x[i];
    std::reverse(reverse_shape[i].begin(), reverse_shape[i].end());
  }

  // 1-extend and align all vectors.
  for (int i = 0; i < kInputNum; ++i) {
    if (reverse_shape[i].size() < largest_rank) {
      reverse_shape[i].resize(largest_rank, 1);
    }
  }
  grad_reduce_idx = GetGradientIndices(reverse_shape, largest_rank);
  return grad_reduce_idx;
}

std::vector<int64_t> GetInputShape(const CNodePtr &cnode, size_t index) {
  auto address_x = AnfAlgo::GetPrevNodeMutableOutputAddr(cnode, index);
  auto shape_x = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
  auto type_x = AnfAlgo::GetOutputInferDataType(cnode, index);
  if (type_x != TypeId::kNumberTypeInt64) {
    MS_LOG(EXCEPTION) << "Input x type must be int64, but got: " << type_x;
  }
  if (shape_x.size() != 1) {
    MS_LOG(EXCEPTION) << "Input " << index << " must be 1-D, but got " << shape_x.size() << "-D.";
  }

  size_t x_num = shape_x[0];
  std::vector<int64_t> x{SizeToLong(x_num)};

  auto x_shape_value = std::make_shared<tensor::Tensor>(type_x, x);
  x_shape_value->set_device_address(address_x);
  x_shape_value->data_sync();

  auto x_value = reinterpret_cast<int64_t *>(x_shape_value->data_c());
  MS_EXCEPTION_IF_NULL(x_value);
  std::vector<int64_t> input_shape = {x_value, x_value + x_num};
  return input_shape;
}

size_t SetOutputValue(const CNodePtr &cnode, const std::vector<std::vector<int64_t>> &grad_reduce_idx, size_t index,
                      size_t input_num) {
  std::vector<int64_t> output;
  size_t idx_num = grad_reduce_idx[index].size();

  for (size_t k = 0; k < idx_num; ++k) {
    output.push_back(grad_reduce_idx[index][idx_num - 1 - k]);
  }

  auto out_addr = AnfAlgo::GetOutputAddr(cnode, index);
  MS_EXCEPTION_IF_NULL(out_addr);

  size_t out_size = idx_num;
  if (idx_num == 0) {
    out_size = input_num;
    for (size_t k = 0; k < input_num; ++k) {
      output.push_back(k);
    }
  }

  std::vector<int64_t> out_shape{SizeToLong(out_size)};
  auto output_type = TypeId::kNumberTypeInt64;
  auto tensor_for_sync = std::make_shared<tensor::Tensor>(output_type, out_shape);

  auto data_ptr = static_cast<int64_t *>(tensor_for_sync->data_c());
  for (size_t i = 0; i < out_size; ++i) {
    MS_LOG(DEBUG) << "DEBUG r" << index << "_output_shape[" << i << "]:" << output[i];
    *(data_ptr + i) = output[i];
  }

  out_addr->SyncHostToDevice(out_shape, LongToSize(tensor_for_sync->data().nbytes()), tensor_for_sync->data_type(),
                             tensor_for_sync->data_c(), tensor_for_sync->device_info().host_format_);
  return out_size;
}
} // namespace

void DynamicBroadcastGradientArgsKernel::Execute() {
  MS_LOG(INFO) << "Execute DynamicBroadcastGradientArgsKernel Start";
  auto cnode = cnode_ptr_.lock();
  MS_EXCEPTION_IF_NULL(cnode);
  auto input_num = AnfAlgo::GetInputTensorNum(cnode);
  if (input_num != 2) {
    MS_LOG(EXCEPTION) << "Invalid input num: " << input_num;
  }

  std::vector<std::vector<int64_t>> input_shapes(kInputNum);
  input_shapes[0] = GetInputShape(cnode, 0);
  input_shapes[1] = GetInputShape(cnode, 1);
  auto grad_reduce_idx = CalculateOutput(input_shapes);

  auto r0_size = SetOutputValue(cnode, grad_reduce_idx, 0, input_shapes[0].size());
  auto r1_size = SetOutputValue(cnode, grad_reduce_idx, 1, input_shapes[1].size());

  std::vector<size_t> r0_shp{r0_size};
  std::vector<size_t> r1_shp{r1_size};
  auto output_type = TypeId::kNumberTypeInt64;
  AnfAlgo::SetOutputInferTypeAndShape({output_type, output_type}, {r0_shp, r1_shp}, cnode.get());
  MS_LOG(INFO) << "Execute DynamicBroadcastGradientArgsKernel End";
}

device::DynamicKernelPtr DynamicBroadcastGradientArgsKernelMod::GenDynamicKernel(const CNodePtr &cnode_ptr,
                                                                                 void *stream_ptr) {
  return std::make_shared<DynamicBroadcastGradientArgsKernel>(stream_ptr, cnode_ptr);
}
} // namespace kernel
} // namespace mindspore
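For readers less familiar with the BroadcastGradientArgs semantics, the following standalone Python sketch (not part of this commit; the helper name broadcast_gradient_args is hypothetical) mirrors the reduction-index logic of CalculateOutput/GetGradientIndices above and reproduces the example used in the docstring and test below.

def broadcast_gradient_args(s0, s1):
    """Return, for each input shape, the axes that were broadcast, i.e. the axes
    the corresponding gradient must be summed over. Mirrors CalculateOutput above."""
    if list(s0) == list(s1):
        return [[], []]
    # Reverse and 1-extend both shapes so trailing dimensions align.
    largest_rank = max(len(s0), len(s1))
    rev = [list(reversed(s)) + [1] * (largest_rank - len(s)) for s in (s0, s1)]
    grad_reduce_idx = [[], []]
    for j in range(largest_rank):
        dims = [rev[i][j] for i in range(2)]
        if dims[0] != dims[1] and min(dims) != 1:
            raise ValueError("shapes cannot be broadcast: {} vs {}".format(s0, s1))
        out_dim = max(dims)
        if out_dim == 1:
            # All dims are 1 at this position; the C++ code marks it for both inputs.
            for i in range(2):
                grad_reduce_idx[i].append(largest_rank - 1 - j)
            continue
        for i in range(2):
            # Axis j (counted from the end) was broadcast for input i.
            if dims[i] == 1:
                grad_reduce_idx[i].append(largest_rank - 1 - j)
    # Return indices in increasing order, matching the reversal done in SetOutputValue.
    return [sorted(idx) for idx in grad_reduce_idx]


# Worked example: (4, 2, 1) broadcast with (2, 7) gives output (4, 2, 7);
# input 0 is broadcast along axis 2 and input 1 along axis 0.
assert broadcast_gradient_args((4, 2, 1), (2, 7)) == [[2], [0]]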
@@ -0,0 +1,44 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_DYNAMIC_BROADCAST_GRADIENT_ARGS_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_DYNAMIC_BROADCAST_GRADIENT_ARGS_KERNEL_H_
#include <vector>
#include <memory>
#include <string>
#include "runtime/device/ascend/executor/host_dynamic_kernel.h"
#include "backend/kernel_compiler/host/host_kernel_mod.h"

using HostDynamicKernel = mindspore::device::ascend::HostDynamicKernel;
namespace mindspore {
namespace kernel {
class DynamicBroadcastGradientArgsKernel : public HostDynamicKernel {
 public:
  DynamicBroadcastGradientArgsKernel(void *stream, const CNodePtr &cnode_ptr) : HostDynamicKernel(stream, cnode_ptr) {}
  ~DynamicBroadcastGradientArgsKernel() override = default;
  void Execute() override;
};

class DynamicBroadcastGradientArgsKernelMod : public HostKernelMod {
 public:
  DynamicBroadcastGradientArgsKernelMod() = default;
  ~DynamicBroadcastGradientArgsKernelMod() override = default;
  device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) override;
};
MS_HOST_REG_KERNEL(DynamicBroadcastGradientArgs, DynamicBroadcastGradientArgsKernelMod);
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_DYNAMIC_BROADCAST_GRADIENT_ARGS_KERNEL_H_
@@ -29,11 +29,6 @@ void HostMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K
  MS_LOG(INFO) << "HostMetadataInfo.";
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(kernel_info_list);
  std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
  if (op_name != kDynamicShape) {
    MS_LOG(DEBUG) << "Host does not have op [" << op_name << "]";
    return;
  }

  std::vector<std::string> inputs_format{};
  std::vector<TypeId> inputs_type{};
@@ -74,6 +74,7 @@ constexpr auto kFastGeLU = "FastGeLU";
constexpr auto kFastGeLUGrad = "FastGeLUGrad";
constexpr auto kZerosLike = "ZerosLike";
constexpr auto kOnesLike = "OnesLike";
constexpr auto kDynamicBroadcastGradientArgs = "DynamicBroadcastGradientArgs";

// NN
constexpr auto kCTCLoss = "CTCLoss";
@@ -632,6 +633,8 @@ inline const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("strin
inline const PrimitivePtr kPrimDictLen = std::make_shared<Primitive>("dict_len");
inline const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
inline const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("BroadcastGradientArgs");
inline const PrimitivePtr kPrimDynamicBroadcastGradientArgs =
  std::make_shared<Primitive>(kDynamicBroadcastGradientArgs);

class DoSignaturePrimitive : public Primitive {
 public:
@@ -0,0 +1,80 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/dynamic_broadcast_gradient_args.h"

#include <set>
#include <string>
#include <vector>
#include <map>
#include <memory>

#include "ops/op_utils.h"

namespace mindspore {
namespace ops {
namespace {
size_t CheckInputsAndGetShape(const AbstractBasePtr &input_arg, const string &prim_name) {
  MS_EXCEPTION_IF_NULL(input_arg);
  if (input_arg->isa<abstract::AbstractTensor>()) {
    auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_arg->BuildShape())[kShape];
    auto input_size = input_shape.size();
    if (input_size != 1) {
      MS_EXCEPTION(TypeError) << prim_name << " input must be 1-D, but got " << input_size << "-D.";
    }
    if (input_shape[0] == abstract::Shape::SHP_ANY) {
      auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_arg->BuildShape())[kMaxShape];
      if (max_shape.empty()) {
        MS_LOG(EXCEPTION) << prim_name << " input shape is dynamic, but max shape is empty.";
      }
      return max_shape[0];
    }
    return input_shape[0];
  } else if (input_arg->isa<abstract::AbstractTuple>()) {
    auto x_shape = dyn_cast<abstract::AbstractTuple>(input_arg);
    auto x_shape_data = x_shape->elements();
    return x_shape_data.size();
  } else {
    MS_EXCEPTION(TypeError) << prim_name << " input must be a tuple or Tensor.";
  }
}

abstract::TupleShapePtr Infer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, prim_name);
  auto x_shape = CheckInputsAndGetShape(input_args[0], prim_name);
  auto y_shape = CheckInputsAndGetShape(input_args[1], prim_name);

  ShapeVector shape{abstract::Shape::SHP_ANY};
  ShapeVector min_shape{1L};
  size_t max_size = x_shape > y_shape ? x_shape : y_shape;
  ShapeVector max_shape{SizeToLong(max_size)};
  auto out_shape = std::make_shared<abstract::Shape>(shape, min_shape, max_shape);
  return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{out_shape, out_shape});
}
} // namespace

AbstractBasePtr DynamicBroadcastGradientArgsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                                  const std::vector<AbstractBasePtr> &input_args) {
  auto types = std::vector<TypePtr>{kInt64, kInt64};
  auto output_type = std::make_shared<Tuple>(types);
  return abstract::MakeAbstract(Infer(primitive, input_args), output_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(DynamicBroadcastGradientArgs, prim::kPrimDynamicBroadcastGradientArgs,
                             DynamicBroadcastGradientArgsInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
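Infer() above only pins down bounds for the dynamic output lengths. A small illustrative check of those bounds, not part of this commit (expected_output_bounds is a hypothetical helper):

# For 1-D inputs s0 (length n0) and s1 (length n1), Infer() marks each output as
# 1-D with a dynamic length that is at least 1 and at most max(n0, n1).
def expected_output_bounds(n0, n1):
    return 1, max(n0, n1)

# Example: a rank-3 shape and a rank-2 shape give reduction-index vectors of length 1..3.
assert expected_output_bounds(3, 2) == (1, 3)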
@@ -0,0 +1,39 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_DYNAMIC_BROADCAST_GRADIENT_ARGS_H_
#define MINDSPORE_CORE_OPS_DYNAMIC_BROADCAST_GRADIENT_ARGS_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
namespace ops {
class DynamicBroadcastGradientArgs : public PrimitiveC {
 public:
  DynamicBroadcastGradientArgs() : PrimitiveC(prim::kPrimDynamicBroadcastGradientArgs->name()) {}
  ~DynamicBroadcastGradientArgs() = default;
  MS_DECLARE_PARENT(DynamicBroadcastGradientArgs, PrimitiveC);
  void Init() {}
};
AbstractBasePtr DynamicBroadcastGradientArgsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                                  const std::vector<AbstractBasePtr> &input_args);
using PrimDynamicBroadcastGradientArgsPtr = std::shared_ptr<DynamicBroadcastGradientArgs>;
} // namespace ops
} // namespace mindspore

#endif // MINDSPORE_CORE_OPS_DYNAMIC_BROADCAST_GRADIENT_ARGS_H_
@@ -1,40 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""DynamicShape op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

dynamic_shape_op_info = AiCPURegOp("DynamicShape") \
    .fusion_type("OPAQUE") \
    .input(0, "x", "required") \
    .output(0, "y", "required") \
    .dtype_format(DataType.I8_Default, DataType.I32_Default) \
    .dtype_format(DataType.I16_Default, DataType.I32_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I32_Default) \
    .dtype_format(DataType.U8_Default, DataType.I32_Default) \
    .dtype_format(DataType.U16_Default, DataType.I32_Default) \
    .dtype_format(DataType.U32_Default, DataType.I32_Default) \
    .dtype_format(DataType.U64_Default, DataType.I32_Default) \
    .dtype_format(DataType.F16_Default, DataType.I32_Default) \
    .dtype_format(DataType.F32_Default, DataType.I32_Default) \
    .dtype_format(DataType.F64_Default, DataType.I32_Default) \
    .dtype_format(DataType.BOOL_Default, DataType.I32_Default) \
    .get_op_info()


@op_info_register(dynamic_shape_op_info)
def _dynamic_shape_aicpu():
    """Unique AiCPU register"""
    return
@@ -21,7 +21,7 @@ from ..._checkparam import Rel
from ..._checkparam import Validator as validator
from ... import context
from ...common import dtype as mstype
from ..primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register
from ..primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive
from ..operations.math_ops import _infer_shape_reduce
from ...communication.management import GlobalComm
from .. import signature as sig
@@ -1111,3 +1111,45 @@ class DynamicStitch(PrimitiveWithCheck):
                                      mstype.number_type + (mstype.bool_,), self.name)
            validator.check(f"type of data[{i}]", data_type[i], f"type of data[0]", data_type[0], Rel.EQ, self.name)
        return data_type[0]


class DynamicBroadcastGradientArgs(Primitive):
    """
    Broadcast the two input shapes and return the dimensions along which each input needs to be reduced.

    Input shapes `s0` and `s1` can be broadcast to a common shape if, for each dimension pair, the dimensions are
    either equal, one of them is 1, or the target dimension is -1. A -1 in the target shape is replaced by the
    input shape's value in that dimension.

    Inputs:
        - **s0** (Tensor) - A `1-D` tensor. The data type should be one of the following types: int32, int64,
          uint32, uint64.
        - **s1** (Tensor) - A `1-D` tensor with the same type as `s0`.

    Outputs:
        Tuple(Tensor), tuple of 2 tensors, `r0` and `r1`. The first one is the reduction index tensor for `s0`
        and the second one is the reduction index tensor for `s1`.

        - **r0** (Tensor) - The output shape is 1-D with the same type as `s0`.
        - **r1** (Tensor) - The output shape is 1-D with the same type as `s0`.

    Raises:
        ValueError: If `s0` and `s1` are incompatible, or if a -1 in the target shape is in an invalid location.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> from mindspore import Tensor
        >>> from mindspore.ops.operations import _inner_ops
        >>> shape0 = (4, 2, 1)
        >>> shape1 = (2, 7)
        >>> args = _inner_ops.DynamicBroadcastGradientArgs()
        >>> r0, r1 = args(Tensor(shape0), Tensor(shape1))
        >>> print(r0, r1)
        [2], [0]
    """

    @prim_attr_register
    def __init__(self):
        """Init DynamicBroadcastGradientArgs"""
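A slightly fuller usage sketch of the primitive added above, with explicit int64 shape tensors as required by the host kernel. This is illustrative only and not part of the commit; it assumes an Ascend context with dynamic-shape support.

import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops.operations import _inner_ops

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

# s0 and s1 are the runtime shapes of the two broadcast operands, as 1-D int64 tensors.
s0 = Tensor(np.array([4, 2, 1], np.int64))
s1 = Tensor(np.array([2, 7], np.int64))
args = _inner_ops.DynamicBroadcastGradientArgs()
r0, r1 = args(s0, s1)
# r0 lists the axes of s0 that were broadcast, r1 those of s1: [2] and [0] here.
print(r0, r1)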
@@ -0,0 +1,44 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore.nn as nn
import mindspore.context as context

from mindspore.ops.operations import _inner_ops

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.args = _inner_ops.BroadcastGradientArgs()

    def construct(self, s0, s1):
        return self.args(s0, s1)


def test_net():
    shape0 = (4, 2, 1)
    shape1 = (2, 7)
    net = Net()
    r0, r1 = net(shape0, shape1)
    print(r0, r1)
    r0_expected = [2]
    r1_expected = [0]

    assert np.array_equal(r0_expected, r0.asnumpy())
    assert np.array_equal(r1_expected, r1.asnumpy())