forked from mindspore-Ecosystem/mindspore
!32780 Refactor some CPU operators for using new KernelMod class.
Merge pull request !32780 from yangshuo/kernel_mod_cpu
This commit is contained in:
commit
76ab8ed102
|
@ -13,11 +13,9 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/argmax_cpu_kernel.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "mindspore/core/ops/arg_max.h"
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -27,28 +25,28 @@ constexpr size_t kArgMaxInputsNum = 1;
|
|||
constexpr size_t kArgMaxOutputsNum = 1;
|
||||
constexpr char kKernelName[] = "ArgMax";
|
||||
|
||||
size_t get_element_num(const std::vector<size_t> &shape) {
|
||||
size_t size = 1;
|
||||
for (size_t i = 0; i < shape.size(); i++) {
|
||||
int64_t get_element_num(const std::vector<int64_t> &shape) {
|
||||
int64_t size = 1;
|
||||
for (int64_t i = 0; i < static_cast<int64_t>(shape.size()); i++) {
|
||||
size *= shape[i];
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool check_validation(const std::vector<size_t> &shape, const size_t num_before_axis, const size_t num_after_axis,
|
||||
bool check_validation(const std::vector<int64_t> &shape, const int64_t num_before_axis, const int64_t num_after_axis,
|
||||
const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kArgMaxInputsNum, kKernelName);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kArgMaxOutputsNum, kKernelName);
|
||||
size_t data_size = sizeof(T);
|
||||
size_t input_size = get_element_num(shape) * data_size;
|
||||
size_t output_num = num_before_axis * num_after_axis;
|
||||
size_t output_size = output_num * sizeof(int);
|
||||
if (inputs[0]->size != input_size) {
|
||||
auto data_size = sizeof(T);
|
||||
int64_t input_size = get_element_num(shape) * static_cast<int64_t>(data_size);
|
||||
int64_t output_num = num_before_axis * num_after_axis;
|
||||
int64_t output_size = output_num * static_cast<int64_t>(sizeof(int));
|
||||
if (static_cast<int64_t>(inputs[0]->size) != input_size) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kKernelName << "', the memory size of 'input_x' should be equal to " << input_size
|
||||
<< ", but got the memory size is " << inputs[0]->size;
|
||||
}
|
||||
if (outputs[0]->size != output_size) {
|
||||
if (static_cast<int64_t>(outputs[0]->size) != output_size) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kKernelName << "', the memory size of output should be equal to " << output_size
|
||||
<< ", but got the memory size is " << outputs[0]->size;
|
||||
}
|
||||
|
@ -68,36 +66,48 @@ bool ArgmaxCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inp
|
|||
auto *output = reinterpret_cast<int32_t *>(outputs[0]->addr);
|
||||
|
||||
std::vector<float> array_axis(dim_axis_);
|
||||
for (size_t i = 0; i < num_before_axis_; i++) {
|
||||
size_t src_index_i = i * dim_axis_ * num_after_axis_;
|
||||
for (size_t j = 0; j < num_after_axis_; j++) {
|
||||
size_t src_index_j = src_index_i + j;
|
||||
for (size_t k = 0; k < dim_axis_; k++) {
|
||||
size_t src_index_k = k * num_after_axis_ + src_index_j;
|
||||
for (int64_t i = 0; i < num_before_axis_; i++) {
|
||||
int64_t src_index_i = i * dim_axis_ * num_after_axis_;
|
||||
for (int64_t j = 0; j < num_after_axis_; j++) {
|
||||
int64_t src_index_j = src_index_i + j;
|
||||
for (int64_t k = 0; k < dim_axis_; k++) {
|
||||
int64_t src_index_k = k * num_after_axis_ + src_index_j;
|
||||
array_axis[k] = static_cast<float>(input[src_index_k]);
|
||||
}
|
||||
auto max_ops = std::max_element(array_axis.begin(), array_axis.end());
|
||||
auto max_index = static_cast<int32_t>(std::distance(array_axis.begin(), max_ops));
|
||||
auto dst_index = i * num_after_axis_ + j;
|
||||
auto dst_index = static_cast<size_t>(i * num_after_axis_ + j);
|
||||
output[dst_index] = max_index;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void ArgmaxCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
bool ArgmaxCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
auto kernel_ptr = std::dynamic_pointer_cast<ops::ArgMax>(base_operator);
|
||||
if (!kernel_ptr) {
|
||||
MS_LOG(ERROR) << "cast ArgMax ops failed!";
|
||||
return false;
|
||||
}
|
||||
if (inputs.size() < 1) {
|
||||
MS_LOG(ERROR) << "Argmax input size should not less than 1!";
|
||||
return false;
|
||||
}
|
||||
workspace_size_list_.clear();
|
||||
InitInputOutputSize(inputs, outputs);
|
||||
kernel_name_ = kernel_ptr->name();
|
||||
shape_ = inputs[0]->GetShapeVector();
|
||||
size_t shape_len = shape_.size();
|
||||
if (shape_len == 0) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of 'input_x' should be at least 1, but got 0.";
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'input_x' should be at least 1, but got 0.";
|
||||
return false;
|
||||
}
|
||||
int64_t axis = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS);
|
||||
int64_t axis = kernel_ptr->get_axis();
|
||||
axis += SizeToLong(shape_len);
|
||||
if (axis < 0) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' should be in range [-1, " << (shape_len - 1)
|
||||
<< "], but got " << axis;
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'axis' should be in range [-1, " << (shape_len - 1)
|
||||
<< "], but got " << axis;
|
||||
return false;
|
||||
}
|
||||
axis = axis % SizeToLong(shape_len);
|
||||
num_before_axis_ = 1;
|
||||
|
@ -111,11 +121,7 @@ void ArgmaxCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
|||
}
|
||||
dim_axis_ = shape_[LongToSize(axis)];
|
||||
|
||||
auto build_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
|
||||
if (build_info->GetInputNum() < 1) {
|
||||
MS_LOG(EXCEPTION) << "Argmax input size should not less than 1!";
|
||||
}
|
||||
auto input_type_id = build_info->GetInputDeviceType(0);
|
||||
auto input_type_id = inputs[0]->GetDtype();
|
||||
switch (input_type_id) {
|
||||
case kNumberTypeFloat32:
|
||||
kernel_func_ = &ArgmaxCpuKernelMod::LaunchKernel<float>;
|
||||
|
@ -124,8 +130,10 @@ void ArgmaxCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
|||
kernel_func_ = &ArgmaxCpuKernelMod::LaunchKernel<float16>;
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "Argmax kernel does not support " << TypeIdToString(input_type_id);
|
||||
MS_LOG(ERROR) << "Argmax kernel does not support " << TypeIdToString(input_type_id);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Argmax, ArgmaxCpuKernelMod);
|
||||
|
|
|
@ -25,12 +25,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class ArgmaxCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
||||
class ArgmaxCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
ArgmaxCpuKernelMod() = default;
|
||||
~ArgmaxCpuKernelMod() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
|
@ -46,10 +47,10 @@ class ArgmaxCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
|
||||
ArgmaxFunc kernel_func_;
|
||||
|
||||
std::vector<size_t> shape_;
|
||||
size_t num_before_axis_{0};
|
||||
size_t num_after_axis_{0};
|
||||
size_t dim_axis_{0};
|
||||
std::vector<int64_t> shape_;
|
||||
int64_t num_before_axis_{0};
|
||||
int64_t num_after_axis_{0};
|
||||
int64_t dim_axis_{0};
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -71,17 +71,19 @@ std::vector<KernelAttr> NativeCpuKernelMod::GetAllSupportedList(const std::strin
|
|||
std::vector<KernelAttr> NativeCpuKernelMod::GetSupportFromOpLib(const std::string &kernel_name) {
|
||||
static std::set<std::string> same_op_name = {"Concat", "Pack", "Stack", "Split", "Transpose",
|
||||
"Unpack", "AddN", "ConcatOffset", "DynamicStitch"};
|
||||
std::vector<KernelAttr> support_kernel_attrs;
|
||||
auto op_info = mindspore::kernel::OpLib::FindOp(kernel_name, kernel::OpImplyType::kCPU);
|
||||
if (op_info == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Not find op[" << kernel_name << "] in cpu. For more details, "
|
||||
<< "please refer to the list of supported cpu operations at https://www.mindspore.cn.";
|
||||
MS_LOG(WARNING) << "Not find op[" << kernel_name << "] in cpu. For more details, "
|
||||
<< "please refer to the list of supported cpu operations at https://www.mindspore.cn.";
|
||||
return support_kernel_attrs;
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> support_kernel_attrs;
|
||||
auto inputs_ptr = op_info->inputs_ptr();
|
||||
auto outputs_ptr = op_info->outputs_ptr();
|
||||
if (outputs_ptr.empty()) {
|
||||
MS_LOG(EXCEPTION) << "The output dimension of operator '" << kernel_name << "' should not be zero.";
|
||||
MS_LOG(WARNING) << "The output dimension of operator '" << kernel_name << "' should not be zero.";
|
||||
return support_kernel_attrs;
|
||||
}
|
||||
|
||||
auto support_size = outputs_ptr[0]->dtypes().size();
|
||||
|
|
|
@ -20,31 +20,48 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
void LowerBoundCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
sorted_x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
values_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
|
||||
output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
|
||||
namespace {
|
||||
constexpr size_t kInputsNum = 2;
|
||||
constexpr size_t kOutputsNum = 1;
|
||||
} // namespace
|
||||
|
||||
bool LowerBoundCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
if (inputs.size() != kInputsNum || outputs.size() != kOutputsNum) {
|
||||
MS_LOG(ERROR) << kernel_name_ << ": input and output size should be " << kInputsNum << " and " << kOutputsNum
|
||||
<< ", but get " << inputs.size() << " and " << outputs.size();
|
||||
return false;
|
||||
}
|
||||
workspace_size_list_.clear();
|
||||
InitInputOutputSize(inputs, outputs);
|
||||
sorted_x_shape_ = inputs[0]->GetShapeVector();
|
||||
values_shape_ = inputs[1]->GetShapeVector();
|
||||
output_shape_ = outputs[0]->GetShapeVector();
|
||||
size_t size_exp = 2;
|
||||
if (sorted_x_shape_.size() != values_shape_.size() || sorted_x_shape_.size() != size_exp ||
|
||||
sorted_x_shape_[0] != values_shape_[0]) {
|
||||
MS_LOG(EXCEPTION) << "The shape of input is invalid.";
|
||||
MS_LOG(ERROR) << "The shape of input is invalid.";
|
||||
return false;
|
||||
}
|
||||
sorted_x_num_ = sorted_x_shape_[0] * sorted_x_shape_[1];
|
||||
values_num_ = values_shape_[0] * values_shape_[1];
|
||||
output_num_ = output_shape_[0] * output_shape_[1];
|
||||
sorted_x_num_ = static_cast<size_t>(sorted_x_shape_[0] * sorted_x_shape_[1]);
|
||||
values_num_ = static_cast<size_t>(values_shape_[0] * values_shape_[1]);
|
||||
output_num_ = static_cast<size_t>(output_shape_[0] * output_shape_[1]);
|
||||
if (values_num_ != output_num_) {
|
||||
MS_LOG(EXCEPTION) << "Infer the shape of output error.";
|
||||
MS_LOG(ERROR) << "Infer the shape of output error.";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, LowerBoundFunc> &pair) { return pair.first; });
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, support_list);
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "LowerBound does not support this kernel data type: " << kernel_attr;
|
||||
MS_LOG(ERROR) << "LowerBound does not support this kernel data type: " << kernel_attr;
|
||||
return false;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename I, typename O>
|
||||
|
@ -53,8 +70,8 @@ bool LowerBoundCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr>
|
|||
auto sorted_x_data_addr = reinterpret_cast<I *>(inputs[0]->addr);
|
||||
auto values_data_addr = reinterpret_cast<I *>(inputs[1]->addr);
|
||||
auto output_data_addr = reinterpret_cast<O *>(outputs[0]->addr);
|
||||
size_t sorted_x_data_column = sorted_x_shape_[1];
|
||||
size_t values_data_column = values_shape_[1];
|
||||
size_t sorted_x_data_column = static_cast<size_t>(sorted_x_shape_[1]);
|
||||
size_t values_data_column = static_cast<size_t>(values_shape_[1]);
|
||||
auto task = [this, &values_data_addr, &sorted_x_data_addr, &output_data_addr, &sorted_x_data_column,
|
||||
&values_data_column](size_t start, size_t end) {
|
||||
const size_t kTwo = 2;
|
||||
|
|
|
@ -23,12 +23,13 @@
|
|||
#include "plugin/factory/ms_factory.h"
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class LowerBoundCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
||||
class LowerBoundCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
LowerBoundCpuKernelMod() = default;
|
||||
~LowerBoundCpuKernelMod() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
|
@ -45,9 +46,9 @@ class LowerBoundCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
const std::vector<kernel::AddressPtr> &)>;
|
||||
static std::vector<std::pair<KernelAttr, LowerBoundFunc>> func_list_;
|
||||
LowerBoundFunc kernel_func_;
|
||||
std::vector<size_t> sorted_x_shape_;
|
||||
std::vector<size_t> values_shape_;
|
||||
std::vector<size_t> output_shape_;
|
||||
std::vector<int64_t> sorted_x_shape_;
|
||||
std::vector<int64_t> values_shape_;
|
||||
std::vector<int64_t> output_shape_;
|
||||
size_t sorted_x_num_;
|
||||
size_t values_num_;
|
||||
size_t output_num_;
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "plugin/device/cpu/kernel/maximum_cpu_kernel.h"
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include "mindspore/core/ops/maximum.h"
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -33,18 +34,30 @@ constexpr size_t kMaximumInputsNum = 2;
|
|||
constexpr size_t kMaximumOutputsNum = 1;
|
||||
} // namespace
|
||||
|
||||
void MaximumCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
input_x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
input_y_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
|
||||
output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
|
||||
TypeId input_x_dtype = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
TypeId input_y_dtype = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);
|
||||
bool MaximumCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
auto kernel_ptr = std::dynamic_pointer_cast<ops::Maximum>(base_operator);
|
||||
if (!kernel_ptr) {
|
||||
MS_LOG(ERROR) << "cast Maximum ops failed!";
|
||||
return false;
|
||||
}
|
||||
kernel_name_ = kernel_ptr->name();
|
||||
if (inputs.size() != kMaximumInputsNum || outputs.size() != kMaximumOutputsNum) {
|
||||
MS_LOG(ERROR) << kernel_name_ << ": input and output size should be " << kMaximumInputsNum << " and "
|
||||
<< kMaximumOutputsNum << ", but get " << inputs.size() << " and " << outputs.size();
|
||||
return false;
|
||||
}
|
||||
workspace_size_list_.clear();
|
||||
InitInputOutputSize(inputs, outputs);
|
||||
input_x_shape_ = inputs[0]->GetShapeVector();
|
||||
input_y_shape_ = inputs[1]->GetShapeVector();
|
||||
output_shape_ = outputs[0]->GetShapeVector();
|
||||
TypeId input_x_dtype = inputs[0]->GetDtype();
|
||||
TypeId input_y_dtype = inputs[1]->GetDtype();
|
||||
size_t max_input_shape_size =
|
||||
input_x_shape_.size() > input_y_shape_.size() ? input_x_shape_.size() : input_y_shape_.size();
|
||||
for (size_t i = 0; i < output_shape_.size(); i++) {
|
||||
output_num_ *= output_shape_[i];
|
||||
output_num_ *= static_cast<size_t>(output_shape_[i]);
|
||||
}
|
||||
if ((input_x_shape_.size() == 0 && input_y_shape_.size() != 0) ||
|
||||
(input_x_shape_.size() != 0 && input_y_shape_.size() == 0)) {
|
||||
|
@ -52,20 +65,23 @@ void MaximumCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
|||
} else if (max_input_shape_size == output_shape_.size() && output_shape_.size() != 0) {
|
||||
InitInputTensors(input_x_dtype, input_y_dtype);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||
<< "', inputs should be two tensors or one tensor and one scalar, but got " << input_x_dtype
|
||||
<< " and " << input_y_dtype;
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', inputs should be two tensors or one tensor and one scalar, but got "
|
||||
<< input_x_dtype << " and " << input_y_dtype;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, MaximumLaunchFunc> &pair) { return pair.first; });
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, support_list);
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "Maximum does not support this kernel data type: " << kernel_attr;
|
||||
MS_LOG(ERROR) << "Maximum does not support this kernel data type: " << kernel_attr;
|
||||
return false;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void MaximumCpuKernelMod::InitInputTensorAndScalar(size_t max_input_shape_size) {
|
||||
|
@ -149,18 +165,18 @@ void MaximumCpuKernelMod::InitTensorBroadcastShape() {
|
|||
broadcast_input_y_shape_.resize(max_dims_, 1);
|
||||
broadcast_output_shape_.resize(max_dims_, 1);
|
||||
for (size_t i = 0; i < output_shape_.size(); i++) {
|
||||
broadcast_output_shape_[i] = output_shape_[i];
|
||||
broadcast_output_shape_[i] = static_cast<size_t>(output_shape_[i]);
|
||||
}
|
||||
int input_x_dim_offset = output_shape_.size() - input_x_shape_.size();
|
||||
for (size_t j = 0; j < input_x_shape_.size(); j++) {
|
||||
broadcast_input_x_shape_[j + IntToSize(input_x_dim_offset)] = input_x_shape_[j];
|
||||
input_x_num_ *= input_x_shape_[j];
|
||||
broadcast_input_x_shape_[j + IntToSize(input_x_dim_offset)] = static_cast<size_t>(input_x_shape_[j]);
|
||||
input_x_num_ *= static_cast<size_t>(input_x_shape_[j]);
|
||||
}
|
||||
int input_y_dim_offset = output_shape_.size() - input_y_shape_.size();
|
||||
for (size_t k = 0; k < input_y_shape_.size(); k++) {
|
||||
if (need_broadcast_) {
|
||||
broadcast_input_y_shape_[k + IntToSize(input_y_dim_offset)] = input_y_shape_[k];
|
||||
input_y_num_ *= input_y_shape_[k];
|
||||
broadcast_input_y_shape_[k + IntToSize(input_y_dim_offset)] = static_cast<size_t>(input_y_shape_[k]);
|
||||
input_y_num_ *= static_cast<size_t>(input_y_shape_[k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,12 +24,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class MaximumCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
||||
class MaximumCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
MaximumCpuKernelMod() = default;
|
||||
~MaximumCpuKernelMod() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
|
@ -72,9 +73,9 @@ class MaximumCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
size_t input_x_num_{1};
|
||||
size_t input_y_num_{1};
|
||||
size_t output_num_{1};
|
||||
std::vector<size_t> input_x_shape_;
|
||||
std::vector<size_t> input_y_shape_;
|
||||
std::vector<size_t> output_shape_;
|
||||
std::vector<int64_t> input_x_shape_;
|
||||
std::vector<int64_t> input_y_shape_;
|
||||
std::vector<int64_t> output_shape_;
|
||||
std::vector<size_t> broadcast_input_x_shape_;
|
||||
std::vector<size_t> broadcast_input_y_shape_;
|
||||
std::vector<size_t> broadcast_output_shape_;
|
||||
|
|
Loading…
Reference in New Issue