KNN code is done
parent b4026d37cf
commit 5421ff2abe

KNN.py (61 lines changed)
@@ -6,6 +6,7 @@ from pyspark.sql import SQLContext
 from pyspark.mllib.linalg import Vectors
 import numpy as np
 from operator import add
+from os import listdir
 
 def computeDist(a,b):
     return np.sqrt(np.sum((a-b)**2))
@@ -13,19 +14,9 @@ class KNN:
     '''
     KNN is used for classification
     '''
-    def __init__(self,featuresCol="features", labelCol="label",):
+    def __init__(self,featuresCol="features", labelCol="label"):
         self.featuresCol,self.labelCol = featuresCol,labelCol
 
-    #:dataSet.map(lambda row:row[1],X[0].toArray() + row[0].toArray())
-
-
-    def classifyX(X):
-        '''
-        classify a unlabeled point
-        param: X: a unlabeled point
-        '''
-        print X[0]
-        return X[0]
     def classify(self,inXs,dataSet,k=10):
         '''
         classify unlabeled points in inXs.
@@ -36,7 +27,7 @@ class KNN:
         if len(inXs) != len(dataSet.first()[0].values):
             print "length of features of inXs is not corresponding with dataset's"
             return
-        dis = dataSet.map(lambda row: (row[1],computeDist(row[0].toArray(),np.array(inXs))))
+        dis = dataSet.map(lambda row: (row[1],computeDist(row[0].toArray(),inXs.toArray())))
 
         def f(x):
             print x
@@ -45,19 +36,45 @@ class KNN:
         #print orderedDis
 
         groupLabel = sc.parallelize(orderedDis).map(lambda row:(row[0],1)).reduceByKey(add).takeOrdered(1,key=lambda row:-row[1])[0][0]
         print groupLabel
         return groupLabel
+
+
+def load_data(data_folder):
+    file_list=listdir(data_folder)
+    file_num=len(file_list)
+    datas = list()
+    for i in range(file_num):
+        filename=file_list[i]
+        fr=open('%s/%s' %(data_folder,filename))
+        data_in_line = list()
+        for j in range(32):
+            line_str=fr.readline()
+            for k in range(32):
+                data_in_line.append(int(line_str[k]))
+
+        label = filename.split('.')[0].split("_")[0]
+        # print "file:%s,label is %s"%(filename,label)
+        datas.append((Vectors.dense(data_in_line),float(label)))
+    return datas
 
 if __name__ == "__main__":
     sc = SparkContext(appName="KNN")
     sqlContext = SQLContext(sc)
     dataSet = sqlContext.createDataFrame([(Vectors.dense([2,3,4,5]),1),
                                           (Vectors.dense([1,2,3,4]),1),
-
                                           (Vectors.dense([3,4,5,6]),2)],
                                          ['features','label'])
     inXs = [1,2,3,4]
-
     count,errorCount = 0,0
     knn = KNN()
     knn.classify(inXs,dataSet,10)
+
+    datasetDF = sqlContext.createDataFrame(load_data("train"),["features","label"]).cache()
+    testData = load_data("test")
+
+    for x in testData:
+        prediction = knn.classify(x[0],datasetDF,10)
+        print "%d-%d" %(x[1],prediction)
+        if prediction != x[1]:
+            errorCount += 1
+        count += 1
+    print "error rate is %f(%d/%d)" % (1.0 * errorCount / count,errorCount,count)
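Note on the changed dis line above: classify now calls inXs.toArray(), which fits the Vectors.dense rows produced by load_data, but the demo in __main__ still passes a plain Python list (inXs = [1,2,3,4]), and a list has no toArray() method, so that call would fail. A minimal fix, assuming the knn and dataSet objects from the __main__ block in this diff:

    from pyspark.mllib.linalg import Vectors

    # A plain list has no toArray(); wrap the query point in a Vector so
    # the new dis line (computeDist(..., inXs.toArray())) keeps working.
    inXs = Vectors.dense([1, 2, 3, 4])
    knn.classify(inXs, dataSet, 10)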
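For context, load_data expects the classic text-digits layout: each file holds a 32x32 grid of digit characters, and the label is taken from the file name before the first underscore (filename.split('.')[0].split("_")[0]), so a file like 3_0.txt would be read as a sample of digit 3. A small self-contained sketch of that parsing with a synthetic file (folder and file names here are illustrative):

    import os

    # Write a fake 32x32 digit file named <label>_<index>.txt.
    os.makedirs("train", exist_ok=True)
    with open("train/3_0.txt", "w") as fw:
        for _ in range(32):
            fw.write("0" * 32 + "\n")

    # Same parsing idea as load_data: 32 lines of 32 characters -> 1024
    # ints, label from the file name before the first underscore.
    with open("train/3_0.txt") as fr:
        pixels = [int(ch) for _ in range(32) for ch in fr.readline()[:32]]
    label = float("3_0.txt".split('.')[0].split("_")[0])
    print(len(pixels), label)  # 1024 3.0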
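Finally, a minimal Spark-free sketch of the same algorithm, useful for sanity-checking the logic locally: Euclidean distance to every labelled point, then a majority vote over the k nearest labels, mirroring the reduceByKey(add) + takeOrdered step in the Spark version. The names compute_dist and classify_local are illustrative, not part of this commit:

    import numpy as np
    from collections import Counter

    def compute_dist(a, b):
        # Same Euclidean distance as computeDist in the commit.
        return np.sqrt(np.sum((a - b) ** 2))

    def classify_local(inXs, data_set, k=10):
        # data_set is a list of (feature_list, label) pairs.
        dis = [(label, compute_dist(np.asarray(vec), np.asarray(inXs)))
               for vec, label in data_set]
        nearest = sorted(dis, key=lambda pair: pair[1])[:k]
        # Majority vote over the k nearest labels.
        return Counter(label for label, _ in nearest).most_common(1)[0][0]

    # Same toy data as the __main__ block above.
    data = [([2, 3, 4, 5], 1), ([1, 2, 3, 4], 1), ([3, 4, 5, 6], 2)]
    print(classify_local([1, 2, 3, 4], data, k=3))  # prints 1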