KNN code is done

StarLee 2016-06-22 16:29:34 +08:00
parent b4026d37cf
commit 5421ff2abe
1 changed file with 39 additions and 22 deletions

KNN.py

@@ -6,6 +6,7 @@ from pyspark.sql import SQLContext
 from pyspark.mllib.linalg import Vectors
 import numpy as np
 from operator import add
+from os import listdir

 def computeDist(a,b):
     return np.sqrt(np.sum((a-b)**2))
@@ -13,19 +14,9 @@ class KNN:
     '''
     KNN is used for classification
     '''
-    def __init__(self,featuresCol="features", labelCol="label",):
+    def __init__(self,featuresCol="features", labelCol="label"):
         self.featuresCol,self.labelCol = featuresCol,labelCol
-    #:dataSet.map(lambda row:row[1],X[0].toArray() + row[0].toArray())
-    def classifyX(X):
-        '''
-        classify a unlabeled point
-        param: X: a unlabeled point
-        '''
-        print X[0]
-        return X[0]
     def classify(self,inXs,dataSet,k=10):
         '''
         classify unlabeled points in inXs.
@@ -36,7 +27,7 @@ class KNN:
         if len(inXs) != len(dataSet.first()[0].values):
             print "the number of features in inXs does not match the dataset's"
             return
-        dis = dataSet.map(lambda row: (row[1],computeDist(row[0].toArray(),np.array(inXs))))
+        dis = dataSet.map(lambda row: (row[1],computeDist(row[0].toArray(),inXs.toArray())))
         def f(x):
             print x
@@ -45,19 +36,45 @@ class KNN:
         #print orderedDis
         groupLabel = sc.parallelize(orderedDis).map(lambda row:(row[0],1)).reduceByKey(add).takeOrdered(1,key=lambda row:-row[1])[0][0]
         print groupLabel
+        return groupLabel
+
+def load_data(data_folder):
+    file_list=listdir(data_folder)
+    file_num=len(file_list)
+    datas = list()
+    for i in range(file_num):
+        filename=file_list[i]
+        fr=open('%s/%s' %(data_folder,filename))
+        data_in_line = list()
+        for j in range(32):
+            line_str=fr.readline()
+            for k in range(32):
+                data_in_line.append(int(line_str[k]))
+        label = filename.split('.')[0].split("_")[0]
+        # print "file:%s,label is %s"%(filename,label)
+        datas.append((Vectors.dense(data_in_line),float(label)))
+    return datas
+
 if __name__ == "__main__":
     sc = SparkContext(appName="KNN")
     sqlContext = SQLContext(sc)
-    dataSet = sqlContext.createDataFrame([(Vectors.dense([2,3,4,5]),1),
-                                          (Vectors.dense([1,2,3,4]),1),
-                                          (Vectors.dense([3,4,5,6]),2)],
-                                         ['features','label'])
-    inXs = [1,2,3,4]
+    count,errorCount = 0,0
     knn = KNN()
-    knn.classify(inXs,dataSet,10)
+    datasetDF = sqlContext.createDataFrame(load_data("train"),["features","label"]).cache()
+    testData = load_data("test")
+    for x in testData:
+        prediction = knn.classify(x[0],datasetDF,10)
+        print "%d-%d" %(x[1],prediction)
+        if prediction != x[1]:
+            errorCount += 1
+        count += 1
+    print "error rate is %f(%d/%d)" % (1.0 * errorCount / count,errorCount,count)