forked from mindspore-Ecosystem/mindspore
Support the DynamicBroadcastTo operator on CPU
This commit is contained in:
parent
ac39b6ebca
commit
aeb74b0d64
|
@ -20,7 +20,6 @@
|
|||
namespace mindspore {
|
||||
namespace kernel {
|
||||
// File-local constants used by Launch to validate the address vectors it receives.
namespace {
|
||||
// Input count that Launch hands to CHECK_KERNEL_INPUTS_NUM together with inputs.size().
// NOTE(review): the DynamicBroadcastTo registrations in this file declare two inputs
// (data + int32); confirm that a count of 1 is still the right check for that operator.
constexpr size_t kBroadcastToInputsNum = 1;
|
||||
// Output count that Launch hands to CHECK_KERNEL_OUTPUTS_NUM together with outputs.size().
constexpr size_t kBroadcastToOutputsNum = 1;
|
||||
} // namespace
|
||||
|
||||
|
@ -32,6 +31,21 @@ void BroadcastToCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
|
|||
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
size_t input_shape_size = input_shape_.size();
|
||||
size_t output_shape_size = output_shape_.size();
|
||||
|
||||
for (size_t i = 0; i < input_shape_size; ++i) {
|
||||
shape_info_.input_shape_[i] = SizeToInt(input_shape_[i]);
|
||||
}
|
||||
for (size_t i = 0; i < output_shape_size; ++i) {
|
||||
shape_info_.output_shape_[i] = SizeToInt(output_shape_[i]);
|
||||
}
|
||||
shape_info_.input_shape_size_ = SizeToInt(input_shape_size);
|
||||
shape_info_.output_shape_size_ = SizeToInt(output_shape_size);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void BroadcastToCPUKernel<T>::CheckArgs() {
|
||||
size_t input_shape_size = input_shape_.size();
|
||||
size_t output_shape_size = output_shape_.size();
|
||||
if (output_shape_size < input_shape_size) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||
<< "', input tensor 'input_x' and target shape 'shape' can't "
|
||||
|
@ -56,22 +70,13 @@ void BroadcastToCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
|
|||
<< Vector2Str(input_shape_) << ", and the dimension of target shape 'shape': " << Vector2Str(output_shape_);
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < input_shape_size; ++i) {
|
||||
shape_info_.input_shape_[i] = SizeToInt(input_shape_[i]);
|
||||
}
|
||||
for (size_t i = 0; i < output_shape_size; ++i) {
|
||||
shape_info_.output_shape_[i] = SizeToInt(output_shape_[i]);
|
||||
}
|
||||
shape_info_.input_shape_size_ = SizeToInt(input_shape_size);
|
||||
shape_info_.output_shape_size_ = SizeToInt(output_shape_size);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool BroadcastToCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kBroadcastToInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kBroadcastToOutputsNum, kernel_name_);
|
||||
CheckArgs();
|
||||
const auto *input_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto *output_addr = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
int status = static_cast<int>(NNACL_OK);
|
||||
|
|
|
@ -36,6 +36,8 @@ class BroadcastToCPUKernel : public CPUKernel {
|
|||
const std::vector<AddressPtr> &outputs) override;
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
void CheckArgs();
|
||||
|
||||
private:
|
||||
std::vector<size_t> input_shape_;
|
||||
std::vector<size_t> output_shape_;
|
||||
|
@ -48,6 +50,18 @@ MS_REG_CPU_KERNEL_T(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeInt32).Add
|
|||
BroadcastToCPUKernel, int);
|
||||
MS_REG_CPU_KERNEL_T(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastToCPUKernel, bool);
|
||||
// CPU kernel registrations for DynamicBroadcastTo, all backed by BroadcastToCPUKernel.
// Each entry declares two inputs (a first input of the listed type plus an int32 input)
// and one output whose type matches the first input: float32, int32 and bool in turn.
// NOTE(review): the int32 second input is presumably the runtime target shape — confirm
// against the operator definition.
MS_REG_CPU_KERNEL_T(
|
||||
DynamicBroadcastTo,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
BroadcastToCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
DynamicBroadcastTo,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
BroadcastToCPUKernel, int);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
DynamicBroadcastTo,
|
||||
KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastToCPUKernel, bool);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
Loading…
Reference in New Issue