pangu base

This commit is contained in:
duzhixing 2023-01-04 17:50:20 +08:00
parent 6c9fd1bd63
commit 87fa335b92
17 changed files with 753 additions and 4 deletions

File diff suppressed because one or more lines are too long

View File

@@ -274,6 +274,7 @@ constexpr auto kFillOpName = "Fill";
constexpr auto kFillDOpName = "FillD";
constexpr auto kFillV2OpName = "FillV2";
constexpr auto kFillV2DOpName = "FillV2D";
constexpr auto kFSEDecodeOpName = "FSEDecode";
constexpr auto kFive2FourOpName = "Five2Four";
constexpr auto kFlattenGradOpName = "FlattenGrad";
constexpr auto kFour2FiveOpName = "Four2Five";
@@ -484,6 +485,7 @@ constexpr auto kRandomChoiceWithMaskOpName = "RandomChoiceWithMask";
constexpr auto kRandomShuffleOpName = "RandomShuffle";
constexpr auto kRangeOpName = "Range";
constexpr auto kRangeDOpName = "RangeD";
constexpr auto kQuantDTypeCastOpName = "QuantDTypeCast";
constexpr auto kRealDivOpName = "RealDiv";
constexpr auto kReciprocalOpName = "Reciprocal";
constexpr auto kRecvOpName = "StreamRecv";

View File

@@ -42,6 +42,8 @@ if(EXISTS ${CMAKE_C_COMPILER} AND EXISTS ${CMAKE_CXX_COMPILER})
${CMAKE_CURRENT_SOURCE_DIR}/slice_grad_kernel.cc
${CMAKE_CURRENT_SOURCE_DIR}/random_shuffle_kernel.cc
${CMAKE_CURRENT_SOURCE_DIR}/range_kernel.cc
${CMAKE_CURRENT_SOURCE_DIR}/quant_dtype_cast_kernel.cc
${CMAKE_CURRENT_SOURCE_DIR}/fse_decode_kernel.cc
${CMAKE_CURRENT_SOURCE_DIR}/replay_buffer/reservoir_replay_buffer.cc
${CMAKE_CURRENT_SOURCE_DIR}/replay_buffer/reservoir_replay_buffer_kernels.cc
)

View File

@@ -0,0 +1,244 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/kernel/aicpu/aicpu_ops/fse_decode_kernel.h"
#include <Eigen/Dense>
#include <vector>
#include <string>
#include <thread>
#include <functional>
#include <numeric>
#include "proto/aicpu_tensor.pb.h"
#include "aicpu_sharder/aicpu_sharder.h"
#include "mindspore/core/mindapi/base/type_id.h"
namespace aicpu {
namespace {
const size_t hw_h = 1;
const size_t hw_w = 2;
const size_t fnz_w1 = 4;
const size_t fnz_h1 = 3;
const size_t fnz_h0 = 2;
const size_t fnz_w0 = 1;
const size_t C0NUM = 0;
const size_t C1NUM = 1;
const size_t C2NUM = 2;
const size_t C3NUM = 3;
const size_t C4NUM = 4;
const size_t C5NUM = 5;
const size_t C6NUM = 6;
const size_t C7NUM = 7;
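// Collapse an arbitrary host shape into a (times, h, w) triple: all leading dims fold into
// 'times', the last two dims become h and w; rank-1 shapes map to (1, 1, w).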
bool TransShapeToHW_NZ(const std::vector<int> &host_shape, std::vector<int> *hw_shape) {
if (host_shape.empty()) {
return false;
}
switch (host_shape.size()) {
case 1:
hw_shape->push_back(1);
hw_shape->push_back(1);
hw_shape->push_back(host_shape[0]);
return true;
default:
auto size = host_shape.size();
if (size < C2NUM) {
return false;
}
int64_t times = 1;
for (size_t i = 0; i != size - C2NUM; i++) {
times *= host_shape[i];
}
hw_shape->push_back(times);
hw_shape->push_back(host_shape[size - C2NUM]);
hw_shape->push_back(host_shape[size - 1]);
return true;
}
}
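// Build a flat index map from the host layout to FRACTAL_NZ: result[src_idx] holds the destination
// offset for each source element, including tail columns that do not fill a full w0 block.
// Returns 0 on success, -1 on an empty host shape.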
int NCHW_TO_FRAC_NZ(const std::vector<int> &host_shape, const std::vector<int> &device_shape,
std::vector<int> *result) {
std::vector<int> hw_shape;
if (!TransShapeToHW_NZ(host_shape, &hw_shape)) {
return -1;
}
auto element_num =
std::accumulate(device_shape.begin(), device_shape.end(), static_cast<int>(1), std::multiplies<int>());
result->resize(element_num);
auto times = hw_shape.at(0);
auto h = hw_shape.at(hw_h);
auto w = hw_shape.at(hw_w);
auto hw = h * w;
auto shape_size = device_shape.size();
auto w1 = device_shape[shape_size - fnz_w1];
auto h1 = device_shape[shape_size - fnz_h1];
auto h0 = device_shape[shape_size - fnz_h0];
auto w0 = device_shape[shape_size - fnz_w0];
auto h1h0w0 = h1 * h0 * w0;
auto w1h1h0w0 = w1 * h1h0w0;
auto num_w1 = w / w0;
for (int64_t times_idx = 0; times_idx < times; times_idx++) {
auto times_head = times_idx * w1h1h0w0;
auto src_times_head = times_idx * hw;
for (int64_t h1h0_idx = 0; h1h0_idx < h; h1h0_idx++) {
auto h1h0_head = times_head + h1h0_idx * w0;
auto src_h_head = src_times_head + h1h0_idx * w;
for (int64_t w1_idx = 0; w1_idx < num_w1; w1_idx++) {
for (int64_t i = 0; i < w0; ++i) {
int64_t src_idx = src_h_head + w1_idx * w0 + i;
int64_t dst_idx = h1h0_head + w1_idx * h1h0w0 + i;
result->at(src_idx) = dst_idx;
}
}
auto w1_head = num_w1 * w0;
for (int64_t w0_idx = 0; w1_head + w0_idx < w; w0_idx++) {
auto src_w_idx = w1_head + w0_idx;
int64_t dst_idx = h1h0_head + num_w1 * h1h0w0 + w0_idx;
int64_t src_idx = src_h_head + src_w_idx;
result->at(src_idx) = dst_idx;
}
}
}
return 0;
}
std::vector<int> GetShape(const ::aicpuops::TensorShape &shape) {
std::vector<int> res;
for (int i = 0; i < shape.dim_size(); ++i) {
res.push_back(shape.dim(i).size());
}
return res;
}
} // namespace
bool FSEDecodeKernel::CheckParams() { return true; }
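// Read 'bit_count' bits from the compressed stream. curr_chunk_ caches the active 64-bit word and
// curr_bit_count_ the unread bits in it; 'chunks' is consumed from curr_chunk_index_ downwards,
// refilling the cache when a read drains the word or straddles a word boundary.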
uint64_t FSEDecodeKernel::Pop(const uint64_t *chunks, uint8_t bit_count) {
const int kMaxBitCount = 64;
uint64_t right = curr_chunk_ >> static_cast<size_t>(kMaxBitCount - curr_bit_count_);
uint64_t res = right & ((1ULL << bit_count) - 1);
curr_bit_count_ -= static_cast<int8_t>(bit_count);
if (curr_bit_count_ > 0) {
return res;
}
if (curr_bit_count_ == 0) {
if (curr_chunk_index_ > -1) {
curr_bit_count_ = kMaxBitCount;
curr_chunk_ = chunks[curr_chunk_index_--];
}
return res;
}
curr_bit_count_ += static_cast<int8_t>(bit_count);
curr_chunk_ = chunks[curr_chunk_index_--];
right |= (curr_chunk_ & ((1ULL << (static_cast<int8_t>(bit_count) - curr_bit_count_)) - 1)) << curr_bit_count_;
curr_bit_count_ = kMaxBitCount - (static_cast<int8_t>(bit_count) - curr_bit_count_);
return right;
}
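// FSE decode to float32: starting from the initial state popped from the stream, emit
// centroids[symbol_table[state]] backwards over the output buffer; the next state is
// states_table[state] plus the next bit_count_table[state] bits.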
void FSEDecodeKernel::FixedBitFloatDequantTask() {
uint64_t *chunks = reinterpret_cast<uint64_t *>(io_addrs_[C0NUM]);
uint16_t *states_table = reinterpret_cast<uint16_t *>(io_addrs_[C1NUM]);
uint8_t *bit_count_table = reinterpret_cast<uint8_t *>(io_addrs_[C2NUM]);
uint16_t *symbol_table = reinterpret_cast<uint16_t *>(io_addrs_[C3NUM]);
float *centroids = reinterpret_cast<float *>(io_addrs_[C4NUM]);
float *output = reinterpret_cast<float *>(io_addrs_[C6NUM]);
int out_count = std::accumulate(output_shape_.begin(), output_shape_.end(), 1, std::multiplies<int>());
uint64_t state = Pop(chunks, table_log_);
while ((curr_chunk_index_ >= 0) || (bit_count_table[state] == 0) || (curr_bit_count_ > 0)) {
if (out_count == 0) {
return;
}
output[--out_count] = static_cast<float>(centroids[symbol_table[state]]);
// state = newStateBaseline + rest
state = states_table[state] + Pop(chunks, bit_count_table[state]);
}
}
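// Same decode loop as above, but emits float16 and scatters each value through the
// NCHW -> FRACTAL_NZ index map so the output lands in device layout.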
void FSEDecodeKernel::FixedBitHalfDequantTask() {
uint64_t *chunks = reinterpret_cast<uint64_t *>(io_addrs_[C0NUM]);
uint16_t *states_table = reinterpret_cast<uint16_t *>(io_addrs_[C1NUM]);
uint8_t *bit_count_table = reinterpret_cast<uint8_t *>(io_addrs_[C2NUM]);
uint16_t *symbol_table = reinterpret_cast<uint16_t *>(io_addrs_[C3NUM]);
float *centroids = reinterpret_cast<float *>(io_addrs_[C4NUM]);
int *input_shape = reinterpret_cast<int *>(io_addrs_[C5NUM]);
Eigen::half *output = reinterpret_cast<Eigen::half *>(io_addrs_[C6NUM]);
int out_count = std::accumulate(output_shape_.begin(), output_shape_.end(), 1, std::multiplies<int>());
std::vector<int> input_shape_vector(input_shape, input_shape + input_shape_size_);
std::vector<int> index_maps;
if (NCHW_TO_FRAC_NZ(input_shape_vector, output_shape_, &index_maps) != 0) {
AICPU_LOGE("For 'FSEDecode', transform shape to FRACTAL_NZ layout failed.");
return;
}
uint64_t state = Pop(chunks, table_log_);
while ((curr_chunk_index_ >= 0) || (bit_count_table[state] == 0) || (curr_bit_count_ > 0)) {
if (out_count == 0) {
return;
}
auto out_index = index_maps.at(--out_count);
output[out_index] = static_cast<Eigen::half>(centroids[symbol_table[state]]);
// state = newStateBaseline + rest
state = states_table[state] + Pop(chunks, bit_count_table[state]);
}
}
uint32_t FSEDecodeKernel::FSEDecodeTask() {
if (io_addrs_.size() != C7NUM) {
return kAicpuKernelStateFailed;
}
if (dst_type_ == mindspore::kNumberTypeFloat32) {
FixedBitFloatDequantTask();
} else if (dst_type_ == mindspore::kNumberTypeFloat16) {
FixedBitHalfDequantTask();
} else {
return kAicpuKernelStateInvalid;
}
return kAicpuKernelStateSucess;
}
uint32_t FSEDecodeKernel::ParseKernelParam() {
::google::protobuf::Map<::std::string, ::aicpuops::AttrValue> attrs = node_def_.attrs();
// get value of attr axis
dst_type_ = attrs["dst_t"].i();
curr_chunk_ = attrs["curr_chunk"].i();
curr_chunk_index_ = attrs["curr_chunk_index"].i();
curr_bit_count_ = attrs["curr_bit_count"].i();
table_log_ = attrs["table_log"].i();
// get input tensors shape
if (node_def_.inputs_size() != C6NUM) {
AICPU_LOGE("For 'FSEDecode', input tensor number must be 1, but got %d", node_def_.inputs_size());
return kAicpuKernelStateInvalid;
}
input_shape_size_ = node_def_.inputs(C5NUM).tensor_shape().dim(0).size();
// get output tensor shape
if (node_def_.outputs_size() != 1) {
AICPU_LOGE("For 'FSEDecode', output tensor number must be 1, but got %d", node_def_.outputs_size());
return kAicpuKernelStateInvalid;
}
aicpuops::Tensor output_tensor = node_def_.outputs(0);
output_shape_ = GetShape(output_tensor.tensor_shape());
return kAicpuKernelStateSucess;
}
uint32_t FSEDecodeKernel::DoCompute() { return FSEDecodeTask(); }
} // namespace aicpu
extern "C" {
__attribute__((visibility("default"))) uint32_t FSEDecode(void *param) {
aicpu::FSEDecodeKernel fse_decode_kernel;
return fse_decode_kernel.Compute(param);
}
}

View File

@@ -0,0 +1,47 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_OPS_FSE_DECODE_KERNEL_H_
#define AICPU_OPS_FSE_DECODE_KERNEL_H_
#include <vector>
#include <random>
#include "common/kernel_base.h"
namespace aicpu {
class FSEDecodeKernel : public KernelBase {
public:
FSEDecodeKernel() : KernelBase("FSEDecodeKernel") {}
~FSEDecodeKernel() = default;
protected:
uint32_t ParseKernelParam() override;
uint32_t DoCompute() override;
bool CheckParams();
uint32_t FSEDecodeTask();
void FixedBitHalfDequantTask();
void FixedBitFloatDequantTask();
uint64_t Pop(const uint64_t *chunks, uint8_t bit_count);
std::vector<int> output_shape_;
int64_t input_shape_size_{0};
int64_t dst_type_{0};
uint64_t curr_chunk_{0};
int64_t curr_chunk_index_{0};
int64_t curr_bit_count_{0};
int64_t table_log_{0};
};
} // namespace aicpu
#endif // AICPU_OPS_FSE_DECODE_KERNEL_H_

View File

@@ -0,0 +1,168 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/kernel/aicpu/aicpu_ops/quant_dtype_cast_kernel.h"
#include <Eigen/Dense>
#include <vector>
#include <string>
#include <thread>
#include <functional>
#include <numeric>
#include <algorithm>
#include "proto/aicpu_tensor.pb.h"
#include "aicpu_sharder/aicpu_sharder.h"
#include "mindspore/core/mindapi/base/type_id.h"
namespace aicpu {
namespace {
constexpr size_t C0NUM = 0;
constexpr size_t C1NUM = 1;
constexpr size_t C2NUM = 2;
constexpr size_t C3NUM = 3;
constexpr size_t C4NUM = 4;
constexpr size_t C5NUM = 5;
constexpr size_t C6NUM = 6;
std::vector<int64_t> GetShape(const ::aicpuops::TensorShape &shape) {
std::vector<int64_t> res;
for (int i = 0; i < shape.dim_size(); ++i) {
res.push_back(shape.dim(i).size());
}
return res;
}
} // namespace
bool QuantDTypeCastKernel::CheckParams() { return true; }
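// Dequantize int8 to float32: with a single set of quant params the whole tensor shares one
// scale/zp; otherwise params are bucketed along 'axis_' (per channel). Work is sharded across
// threads with ParallelFor.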
void QuantDTypeCastKernel::FixedBitFloatDequantTask() {
int8_t *input = reinterpret_cast<int8_t *>(io_addrs_[C0NUM]);
float *scales = reinterpret_cast<float *>(io_addrs_[C1NUM]);
int *zps = reinterpret_cast<int *>(io_addrs_[C2NUM]);
float *mean_corrs = reinterpret_cast<float *>(io_addrs_[C3NUM]);
float *var_corrs = reinterpret_cast<float *>(io_addrs_[C4NUM]);
float *output = reinterpret_cast<float *>(io_addrs_[C5NUM]);
// optimize in the pass.
int64_t element_cnt = std::accumulate(input_shapes_.begin(), input_shapes_.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
if (quant_param_size_ == 1) {
auto dequant = [&](size_t start, size_t end) {
for (size_t pos = start; pos < end; pos++) {
// formula: dequant = (x - zp) * scale
output[pos] = (input[pos] - zps[0]) * scales[0] * var_corrs[0] + mean_corrs[0];
}
};
// hardware_concurrency() may return 0; clamp to avoid dividing by zero.
const int64_t per_unit_size = element_cnt / std::max<int64_t>(1, std::thread::hardware_concurrency());
ParallelFor(element_cnt, per_unit_size, dequant);
} else {
auto bucket_count = input_shapes_[axis_];
size_t stride = 1;
for (size_t i = axis_ + 1; i < input_shapes_.size(); i++) {
stride *= input_shapes_[i];
}
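// Elements sharing the same coordinate along 'axis_' use the same bucket of quant params.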
auto dequant = [&](size_t start, size_t end) {
for (size_t pos = start; pos < end; pos++) {
size_t bucket_index = (pos / stride) % bucket_count;
// formula: dequant = (x - zp) * scale
output[pos] =
(input[pos] - zps[bucket_index]) * scales[bucket_index] * var_corrs[bucket_index] + mean_corrs[bucket_index];
}
};
const int64_t per_unit_size = element_cnt / std::max<int64_t>(1, std::thread::hardware_concurrency());
ParallelFor(element_cnt, per_unit_size, dequant);
}
}
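// float16 variant of the task above; each dequantized value is converted to Eigen::half.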
void QuantDTypeCastKernel::FixedBitHalfDequantTask() {
int8_t *input = reinterpret_cast<int8_t *>(io_addrs_[C0NUM]);
float *scales = reinterpret_cast<float *>(io_addrs_[C1NUM]);
int *zps = reinterpret_cast<int *>(io_addrs_[C2NUM]);
float *mean_corrs = reinterpret_cast<float *>(io_addrs_[C3NUM]);
float *var_corrs = reinterpret_cast<float *>(io_addrs_[C4NUM]);
Eigen::half *output = reinterpret_cast<Eigen::half *>(io_addrs_[C5NUM]);
// optimize in the pass.
int64_t element_cnt = std::accumulate(input_shapes_.begin(), input_shapes_.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
if (quant_param_size_ == 1) {
auto dequant = [&](size_t start, size_t end) {
for (size_t pos = start; pos < end; pos++) {
// formula: dequant = (x - zp) * scale
output[pos] = Eigen::half((input[pos] - zps[0]) * scales[0] * var_corrs[0] + mean_corrs[0]);
}
};
const int64_t per_unit_size = element_cnt / std::max<int64_t>(1, std::thread::hardware_concurrency());
ParallelFor(element_cnt, per_unit_size, dequant);
} else {
auto bucket_count = input_shapes_[axis_];
size_t stride = 1;
for (size_t i = axis_ + 1; i < input_shapes_.size(); i++) {
stride *= input_shapes_[i];
}
auto dequant = [&](size_t start, size_t end) {
for (size_t pos = start; pos < end; pos++) {
size_t bucket_index = (pos / stride) % bucket_count;
// formula: dequant = (x - zp) * scale
output[pos] = Eigen::half((input[pos] - zps[bucket_index]) * scales[bucket_index] * var_corrs[bucket_index] +
mean_corrs[bucket_index]);
}
};
const int64_t per_unit_size = element_cnt / std::max<int64_t>(1, std::thread::hardware_concurrency());
ParallelFor(element_cnt, per_unit_size, dequant);
}
}
uint32_t QuantDTypeCastKernel::QuantDTypeCastTask() {
if (io_addrs_.size() != C6NUM) {
return kAicpuKernelStateFailed;
}
if (dst_type_ == mindspore::kNumberTypeFloat32) {
FixedBitFloatDequantTask();
} else if (dst_type_ == mindspore::kNumberTypeFloat16) {
FixedBitHalfDequantTask();
} else {
return kAicpuKernelStateInvalid;
}
return kAicpuKernelStateSucess;
}
uint32_t QuantDTypeCastKernel::ParseKernelParam() {
::google::protobuf::Map<::std::string, ::aicpuops::AttrValue> attrs = node_def_.attrs();
// get value of attr axis
axis_ = attrs["axis"].i();
dst_type_ = attrs["dst_t"].i();
src_type_ = attrs["src_t"].i();
// get input tensors shape
if (node_def_.inputs_size() != C5NUM) {
AICPU_LOGE("For 'QuantDTypeCast', input tensor number must be 1, but got %d", node_def_.inputs_size());
return kAicpuKernelStateInvalid;
}
aicpuops::Tensor input_tensor = node_def_.inputs(0);
input_shapes_ = GetShape(input_tensor.tensor_shape());
quant_param_size_ = node_def_.inputs(1).tensor_shape().dim(0).size();
// get output tensor shape
if (node_def_.outputs_size() != 1) {
AICPU_LOGE("For 'QuantDTypeCast', output tensor number must be 1, but got %d", node_def_.outputs_size());
return kAicpuKernelStateInvalid;
}
return kAicpuKernelStateSucess;
}
uint32_t QuantDTypeCastKernel::DoCompute() { return QuantDTypeCastTask(); }
} // namespace aicpu
extern "C" {
__attribute__((visibility("default"))) uint32_t QuantDTypeCast(void *param) {
aicpu::QuantDTypeCastKernel quant_dtype_cast_kernel;
return quant_dtype_cast_kernel.Compute(param);
}
}

View File

@@ -0,0 +1,44 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_OPS_QUANT_DTYPE_CAST_KERNEL_H_
#define AICPU_OPS_QUANT_DTYPE_CAST_KERNEL_H_
#include <vector>
#include <random>
#include "common/kernel_base.h"
namespace aicpu {
class QuantDTypeCastKernel : public KernelBase {
public:
QuantDTypeCastKernel() : KernelBase("QuantDTypeCastKernel") {}
~QuantDTypeCastKernel() = default;
protected:
uint32_t ParseKernelParam() override;
uint32_t DoCompute() override;
bool CheckParams();
uint32_t QuantDTypeCastTask();
void FixedBitHalfDequantTask();
void FixedBitFloatDequantTask();
std::vector<int64_t> input_shapes_;
int64_t quant_param_size_{0};
int64_t axis_{0};
int64_t dst_type_{0};
int64_t src_type_{0};
};
} // namespace aicpu
#endif // AICPU_OPS_QUANT_DTYPE_CAST_KERNEL_H_

View File

@@ -195,6 +195,8 @@ constexpr auto kSmoothL1Loss = "SmoothL1Loss";
constexpr auto kSmoothL1LossGrad = "SmoothL1LossGrad";
constexpr auto kSparseCross = "SparseCross";
constexpr auto kChannelShuffle = "ChannelShuffle";
constexpr auto kQuantDTypeCast = "QuantDTypeCast";
constexpr auto kFSEDecode = "FSEDecode";
const std::set<std::string> kCpuKernelOps{kIdentity,
kMaskedSelect,
@@ -279,7 +281,9 @@ const std::set<std::string> kCpuKernelBaseOps{kDropoutGenMaskOpName,
kConcatOffset,
kSliceGrad,
kRandomShuffle,
kRange};
kRange,
kQuantDTypeCast,
kFSEDecode};
const std::set<std::string> kDynamicInputOps{kRaggedTensorToTensor,
kSparseCross,
kRaggedTensorToSparse,

View File

@@ -170,6 +170,8 @@
#include "plugin/device/ascend/optimizer/mindir/all_to_all_unify_mindir.h"
#include "plugin/device/ascend/optimizer/mindir/neighbor_exchange_v2_unify_mindir.h"
#include "plugin/device/ascend/optimizer/mindir/ascend_vm_op_adapter.h"
#include "plugin/device/ascend/optimizer/mindir/quant_dtype_cast_adjust.h"
#include "plugin/device/ascend/optimizer/mindir/fse_decode_adjust.h"
#include "backend/common/pass/adjust_depend_for_parallel_optimizer_recompute_all_gather.h"
#include "backend/common/pass/gradients_allreduce_depend_last_send.h"
#include "backend/common/pass/optimize_gradients_allreduce_overlap.h"
@@ -701,6 +703,8 @@ void AscendUnifyMindIR(const std::shared_ptr<session::KernelGraph> &kernel_graph
unify_mindir_pm->AddPass(std::make_shared<opt::NeighborExchangeV2GradUnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::AllToAllUnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::AICpuLibSelectPass>());
unify_mindir_pm->AddPass(std::make_shared<opt::QuantDTypeCastAdjust>());
unify_mindir_pm->AddPass(std::make_shared<opt::FSEDecodeAdjust>());
optimizer->AddPassManager(unify_mindir_pm);
(void)optimizer->Optimize(kernel_graph);

View File

@@ -47,7 +47,9 @@ const AnfNodePtr AICpuLibSelectPass::Process(const FuncGraphPtr &graph, const An
kConcatOffsetOpName,
kSliceGradOpName,
kRandomShuffleOpName,
kRangeOpName};
kRangeOpName,
kQuantDTypeCastOpName,
kFSEDecodeOpName};
static const std::set<std::string> kMigrateAicpuKernelOps = {mindspore::kAdaptiveAvgPool2dOpName,
mindspore::kAdaptiveAvgPool2dGradOpName,
mindspore::kCacheSwapTableOpName,

View File

@@ -0,0 +1,50 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/optimizer/mindir/fse_decode_adjust.h"
#include <vector>
#include <memory>
#include "include/common/utils/utils.h"
#include "utils/ms_context.h"
#include "backend/common/optimizer/helper.h"
#include "include/common/utils/anfalgo.h"
#include "utils/trace_base.h"
#include "ops/op_name.h"
namespace mindspore {
namespace opt {
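// Match any FSEDecode CNode regardless of its inputs; the pass only rewrites attributes.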
const BaseRef FSEDecodeAdjust::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
auto prim = std::make_shared<Primitive>(kFSEDecodeOpName);
return VectorRef({prim, Xs});
}
const AnfNodePtr FSEDecodeAdjust::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(func_graph);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto primitive = common::AnfAlgo::GetCNodePrimitive(cnode);
MS_EXCEPTION_IF_NULL(primitive);
primitive->DelAttr("format");
primitive->DelAttr("infer_done");
MS_LOG(INFO) << cnode->fullname_with_scope() << " runs FSEDecodeAdjust pass.";
return node;
}
} // namespace opt
} // namespace mindspore

View File

@@ -0,0 +1,34 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FSE_DECODE_ADJUST_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FSE_DECODE_ADJUST_H_
#include <vector>
#include "backend/common/optimizer/optimizer.h"
#include "backend/common/optimizer/helper.h"
namespace mindspore {
namespace opt {
class FSEDecodeAdjust : public PatternProcessPass {
public:
explicit FSEDecodeAdjust(bool multigraph = true) : PatternProcessPass("fse_decode_adjust", multigraph) {}
~FSEDecodeAdjust() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FSE_DECODE_ADJUST_H_

View File

@@ -0,0 +1,73 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/optimizer/mindir/quant_dtype_cast_adjust.h"
#include <vector>
#include <memory>
#include <string>
#include "include/common/utils/utils.h"
#include "utils/ms_context.h"
#include "backend/common/optimizer/helper.h"
#include "include/common/utils/anfalgo.h"
#include "utils/trace_base.h"
#include "ops/op_name.h"
#include "runtime/device/ms_device_shape_transfer.h"
namespace mindspore {
namespace opt {
const BaseRef QuantDTypeCastAdjust::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
auto prim = std::make_shared<Primitive>(kQuantDTypeCastOpName);
return VectorRef({prim, Xs});
}
const AnfNodePtr QuantDTypeCastAdjust::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(func_graph);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
MS_LOG(DEBUG) << cnode->fullname_with_scope() << " runs QuantDTypeCastAdjust pass.";
auto primitive = common::AnfAlgo::GetCNodePrimitive(cnode);
MS_EXCEPTION_IF_NULL(primitive);
primitive->DelAttr("format");
primitive->DelAttr("infer_done");
const auto cnode_output_format = AnfAlgo::GetOutputFormat(cnode, 0);
if (cnode_output_format == kOpFormat_FRAC_NZ) {
auto param_node = cnode->input(1)->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param_node);
auto tensor_info = param_node->default_param()->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(tensor_info);
auto host_shape = tensor_info->shape_c();
auto size = tensor_info->Size();
auto host_ptr = tensor_info->data_c();
auto device_shape = trans::TransShapeToDevice(host_shape, kOpFormat_FRAC_NZ, kNumberTypeFloat16);
const trans::FormatArgs format_args{host_ptr, size, kOpFormat_NCHW, kOpFormat_FRAC_NZ,
host_shape, device_shape, kNumberTypeInt8};
auto host_tmp = std::vector<uint8_t>(size);
MS_LOG(DEBUG) << "TransFormat host_shape:" << host_shape << " device_shape:" << device_shape;
auto ret = trans::TransFormat(format_args, host_tmp.data(), cnode, 1);
if (!ret) {
MS_LOG(ERROR) << "Trans format failed.";
return nullptr;
}
if (memcpy_s(tensor_info->data_c(), tensor_info->Size(), host_tmp.data(), host_tmp.size()) != EOK) {
MS_LOG(ERROR) << "memcpy failed.";
return nullptr;
}
}
return node;
}
} // namespace opt
} // namespace mindspore

View File

@@ -0,0 +1,34 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_QUANT_DTYPE_CAST_ADJUST_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_QUANT_DTYPE_CAST_ADJUST_H_
#include <vector>
#include "backend/common/optimizer/optimizer.h"
#include "backend/common/optimizer/helper.h"
namespace mindspore {
namespace opt {
class QuantDTypeCastAdjust : public PatternProcessPass {
public:
explicit QuantDTypeCastAdjust(bool multigraph = true) : PatternProcessPass("quant_dtype_cast_adjust", multigraph) {}
~QuantDTypeCastAdjust() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_QUANT_DTYPE_CAST_ADJUST_H_

View File

@@ -1132,6 +1132,8 @@ void AnfAlgo::ReorderExecList(NotNull<std::vector<CNodePtr> *> node_list) {
result = DelayExecNode(result, kCastOpName, true);
result = DelayExecNode(result, kAdamApplyOneWithDecayOpName, false);
result = DelayExecNode(result, kAdamApplyOneOpName, false);
result = DelayExecNode(result, kQuantDTypeCastOpName, false);
result = DelayExecNode(result, kFSEDecodeOpName, false);
if (parallel::ParallelContext::GetInstance()->pipeline_stage_split_num() > 1) {
result = DelayExecNode(result, kDropoutGenMaskOpName, true);
result = DelayExecNode(result, kStatelessDropOutGenMaskOpName, true);

View File

@@ -22,6 +22,36 @@
namespace mindspore {
namespace ops {
namespace {
constexpr size_t kTableExtend = 3;
constexpr int kAlignOffset = 7;
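// The cast is shape-preserving: infer shape forwards the input tensor's shape unchanged, while
// infer type reads the target element type from the 'dst_t' attribute.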
abstract::ShapePtr QuantDTypeCastInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
const int64_t kInputNum = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, prim_name);
auto x = input_args[0]->BuildShape();
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
auto shape_element = x->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape_element);
return shape_element;
}
TypePtr QuantDTypeCastInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
const int64_t kInputNum = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, prim_name);
MS_EXCEPTION_IF_NULL(input_args[0]);
auto x_type = input_args[0]->BuildType();
const std::set<TypePtr> valid_types = {kInt8, kFloat32};
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_types, prim_name);
auto type_ptr = mindspore::TypeIdToType(static_cast<TypeId>(GetValue<int64_t>(prim->GetAttr(kDstT))));
return type_ptr;
}
} // namespace
MIND_API_OPERATOR_IMPL(QuantDTypeCast, BaseOperator);
void QuantDTypeCast::set_src_t(const int64_t src_t) { (void)AddAttr(kSrcT, api::MakeValue(src_t)); }
int64_t QuantDTypeCast::get_src_t() const {
@@ -38,6 +68,13 @@ void QuantDTypeCast::Init(const int64_t src_t, const int64_t dst_t) {
this->set_dst_t(dst_t);
}
REGISTER_PRIMITIVE_C(kNameQuantDTypeCast, QuantDTypeCast);
AbstractBasePtr QuantDTypeCastInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto infer_type = QuantDTypeCastInferType(primitive, input_args);
auto infer_shape = QuantDTypeCastInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(QuantDTypeCast, prim::kPrimQuantDTypeCast, QuantDTypeCastInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@@ -165,7 +165,7 @@ int WeightQuantizer::LinearQuant(const FuncGraphPtr &func_graph, const CNodePtr
}
int preferred_dim = GetPreferredDim(cnode, idx - 1, ConvertShapeVectorToInt32(tensor_info->shape()));
if (quant_strategy_ != nullptr && !quant_strategy_->CanTensorQuantized(cnode, input, preferred_dim)) {
MS_LOG(INFO) << input->fullname_with_scope() << " is not quantizable";
MS_LOG(INFO) << input->fullname_with_scope() << " will not be quantized";
continue;
}
// support for matmul shared weight