forked from mindspore-Ecosystem/mindspore
!7655 【MSLITE】Weight quant add condition stop in kmeans method
Merge pull request !7655 from ghzl/weight_quant_add_condition_stop
This commit is contained in:
commit
934bb07656
|
@ -412,11 +412,13 @@ static std::vector<float> InitClusters(float *data, size_t elem_count, size_t k)
|
|||
std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epochs, schema::QuantParamT *quantParam) {
|
||||
std::vector<float> clusters = InitClusters(data, elem_count, k);
|
||||
std::vector<int8_t> clusters_index{};
|
||||
double error{0};
|
||||
if (clusters.size() < k) {
|
||||
MS_LOG(WARNING) << "K is less than the size of data so KMeans function is not executed.";
|
||||
return clusters_index;
|
||||
}
|
||||
for (size_t epoch = 0; epoch < epochs; epoch++) {
|
||||
double error_cur{0};
|
||||
clusters_index.clear();
|
||||
std::vector<std::vector<float>> clusters_data(clusters.size());
|
||||
for (size_t i = 0; i < elem_count; i++) {
|
||||
|
@ -436,6 +438,15 @@ std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epoc
|
|||
clusters[j] = std::accumulate(clusters_data[j].begin(), clusters_data[j].end(), 0.0) / clusters_data[j].size();
|
||||
}
|
||||
}
|
||||
// compare error
|
||||
for (size_t j = 0; j < elem_count; j++) {
|
||||
error_cur += pow(data[j] - clusters[clusters_index[j]], 2);
|
||||
}
|
||||
error_cur = pow(error_cur / elem_count, 0.5);
|
||||
if (std::abs((error_cur - error) / error_cur) < 1e-6) {
|
||||
break;
|
||||
}
|
||||
error = error_cur;
|
||||
}
|
||||
// update data
|
||||
quantParam->clusters = clusters;
|
||||
|
|
Loading…
Reference in New Issue