|
0-9 手写数字识别：对 MNIST 数据集进行识别。
具体代码（包括 MNIST 数据）见附件。
参考资料：Tom Mitchell《机器学习》中反向传播（BP）算法一章。
# coding:utf-8
# Assumes big-endian IDX files (struct format prefix '>'); little-endian files are not handled.
import struct
import numpy
def loadImages(filename):
    """Load an IDX3 image file (e.g. MNIST ``*-images.idx3-ubyte``).

    The header is four big-endian uint32 values (magic number 2051, image
    count, row count, column count), followed by one unsigned byte per pixel.

    :param filename: path of the IDX3 file.
    :returns: float32 array of shape (imageNum, rowNum * colNum), one
        flattened image per row, pixel values in 0-255.
    :raises ValueError: if the magic number is not 2051 or the file is
        shorter than its header claims.
    :raises SystemExit: (exit status 1) if the file cannot be opened.
    """
    try:
        f = open(filename, 'rb')
    except (IOError, OSError) as err:
        print('cannot open %s: %s' % (filename, err))
        # Non-zero status: a bare exit() would report success.
        raise SystemExit(1)
    with f:  # make sure the handle is closed
        bins = f.read()
    header = '>IIII'
    headerSize = struct.calcsize(header)
    if len(bins) < headerSize:
        raise ValueError('dataset damaged | truncated header')
    magicNum, imageNum, rowNum, colNum = struct.unpack_from(header, bins, 0)
    # An explicit raise survives `python -O`, unlike assert.
    if magicNum != 2051:
        raise ValueError('dataset damaged | little endian')
    pixels = rowNum * colNum  # use the header instead of a hard-coded 784
    count = imageNum * pixels
    if len(bins) - headerSize < count:
        raise ValueError('dataset damaged | truncated image data')
    if count == 0:
        return numpy.zeros((imageNum, pixels), dtype='float32')
    # One vectorised read instead of a struct.unpack_from call per image.
    data = numpy.frombuffer(bins, dtype=numpy.uint8, count=count,
                            offset=headerSize)
    return data.reshape(imageNum, pixels).astype('float32')
def loadLabels(filename):
    """Load an IDX1 label file (e.g. MNIST ``*-labels.idx1-ubyte``).

    The header is two big-endian uint32 values (magic number 2049 and label
    count), followed by one unsigned byte per label.

    :param filename: path of the IDX1 file.
    :returns: float32 array of shape (labelNum, 1), one label per row.
    :raises ValueError: if the magic number is not 2049 or the file is
        shorter than its header claims.
    :raises SystemExit: (exit status 1) if the file cannot be opened.
    """
    try:
        f = open(filename, 'rb')
    except (IOError, OSError) as err:
        print('cannot open %s: %s' % (filename, err))
        # Non-zero status: a bare exit() would report success.
        raise SystemExit(1)
    with f:  # make sure the handle is closed
        bins = f.read()
    header = '>II'
    headerSize = struct.calcsize(header)
    if len(bins) < headerSize:
        raise ValueError('dataset damaged | truncated header')
    magicNum, labelNum = struct.unpack_from(header, bins, 0)
    # An explicit raise survives `python -O`, unlike assert.
    if magicNum != 2049:
        raise ValueError('dataset damaged | little endian')
    if len(bins) - headerSize < labelNum:
        raise ValueError('dataset damaged | truncated label data')
    if labelNum == 0:
        return numpy.zeros((0, 1), dtype='float32')
    data = numpy.frombuffer(bins, dtype=numpy.uint8, count=labelNum,
                            offset=headerSize)
    # Keep the (labelNum, 1) column shape callers index into.
    return data.reshape(labelNum, 1).astype('float32')
if '__main__' == __name__:
    # Preview the first three test images together with their labels.
    images = loadImages('t10k-images.idx3-ubyte')
    labels = loadLabels('t10k-labels.idx1-ubyte')
    import matplotlib.pyplot as plt
    for x in range(3):
        plt.figure()
        shown = images[x].reshape(28, 28)  # flattened 784 pixels -> 28x28
        plt.imshow(shown, cmap='gray')
        # labels has shape (N, 1): show the digit itself rather than "[7.]".
        plt.title(str(int(labels[x][0])))
    plt.show()
读取 MNIST
# -*- coding: utf-8 -*-
import dataLoad
import numpy as np
import sys
import warnings
def bp(trainSet, eta=0.01, nin=None, nhid=None, nout=None, iterNum=10):
    """Train a one-hidden-layer sigmoid network by per-sample backpropagation.

    :param trainSet: re-iterable sequence of (instance, label) pairs, e.g.
        [((784,1) array, (10,1) one-hot array), ...].
    :param eta: learning rate.
    :param nin: input size; inferred from the first instance when None.
    :param nhid: number of hidden units (required).
    :param nout: output size; inferred from the first label when None.
    :param iterNum: maximum number of passes (epochs) over trainSet.
    :returns: (Wkh, Whi): hidden->output weights of shape (nout, nhid) and
        input->hidden weights of shape (nhid, nin).
    :raises ValueError: if nhid is not given.

    Before each epoch the current weights are scored with testAnn; training
    stops early once that error rate is 0.04 or lower.
    """
    if nhid is None:
        raise ValueError('nhid (number of hidden units) must be given')
    if nin is None or nout is None:
        firstX, firstLabel = next(iter(trainSet))
        if nin is None:
            nin = firstX.shape[0]
        if nout is None:
            nout = firstLabel.shape[0]
    # Small random weights in [-0.05, 0.05); draw order kept for seeding.
    Wkh = (np.random.rand(nout, nhid) - 0.5) / 10.0
    Whi = (np.random.rand(nhid, nin) - 0.5) / 10.0
    iteration = 0
    er = 1
    # Epoch loop.
    while iteration < iterNum and er > 0.04:
        print('iteration= %d' % iteration)
        er = testAnn((Whi, Wkh))
        iteration += 1
        for (x, label) in trainSet:
            # Min-max scaling to [0, 1]. The range must be parenthesised:
            # (x - min) / max - min is not min-max scaling. Constant inputs
            # (max == min) are only shifted, avoiding a division by zero.
            lo = x.min()
            span = x.max() - lo
            x = (x - lo) / span if span else x - lo
            # Forward pass.
            neth = np.dot(Whi, x)    # (nhid, nin) . (nin, 1) -> (nhid, 1)
            oh = sigmoid(neth)       # (nhid, 1)
            netk = np.dot(Wkh, oh)   # (nout, nhid) . (nhid, 1) -> (nout, 1)
            ok = sigmoid(netk)       # (nout, 1)
            # Error terms: sigmoid derivative times back-propagated error.
            # dh uses Wkh before it is updated below.
            dk = ok * (1 - ok) * (label - ok)
            dh = oh * (1 - oh) * np.dot(Wkh.T, dk)  # (nhid, 1)
            # Weight updates.
            Wkh = Wkh + eta * np.dot(dk, oh.T)  # (nout, 1) . (1, nhid)
            Whi = Whi + eta * np.dot(dh, x.T)   # (nhid, 1) . (1, nin)
    print('iteration over')
    return Wkh, Whi
def testAnn(model, images=None, labels=None):
    """Return the classification error rate of ``model`` on a test set.

    :param model: (Whi, Wkh) weight tuple, as unpacked by fit.
    :param images: test images, one flattened image per row; defaults to
        the module-level ``testImages``.
    :param labels: matching labels, one per row; defaults to the
        module-level ``testLabels``.
    :returns: fraction of misclassified images, as a float.
    :raises ValueError: if the test set is empty.

    The last nine predictions are printed for a quick visual check.
    """
    if images is None:
        images = testImages
    if labels is None:
        labels = testLabels
    total = len(labels)
    if total == 0:
        raise ValueError('empty test set')
    err = 0
    for i in range(total):
        x = images[i].reshape(-1, 1)  # one image as a column vector
        # Same min-max scaling as training, so test inputs share its range.
        lo = x.min()
        span = x.max() - lo
        x = (x - lo) / span if span else x - lo
        res = fit(model, x)
        truth = int(np.ravel(labels[i])[0])
        if i > total - 10:
            print('\t  %d was recognized as %d' % (truth, res[1]))
        if truth != res[1]:
            err += 1
    errorRate = float(err) / float(total)
    print('error rate %s \n' % errorRate)
    return errorRate
def fit(model, Image):
    """Run one forward pass and return (output activations, predicted class).

    :param model: (Whi, Wkh) tuple: input->hidden weights (nhid, nin) and
        hidden->output weights (nout, nhid).
    :param Image: inputs as a column vector of shape (nin, 1).
    :returns: (list of the nout output activations, index of the largest).
    """
    inputWeights, outputWeights = model
    rowVector = Image.T                                   # (1, nin)
    hidden = sigmoid(rowVector.dot(inputWeights.T))       # (1, nhid)
    output = sigmoid(hidden.dot(outputWeights.T))         # (1, nout)
    scores = list(output[0])
    best = max(scores)
    return scores, scores.index(best)
def sigmoid(y):
    """Return the element-wise logistic function 1 / (1 + exp(-y)).

    For large negative inputs exp(-y) overflows to inf, which still gives
    the correct limit 0.0. NumPy's overflow warning for that case is
    silenced only for this computation via np.errstate. No process-wide
    "ignore all warnings" filter is installed on every call.
    """
    with np.errstate(over='ignore'):
        return 1 / (1 + np.exp(-y))
if '__main__' == __name__:
    np.random.seed(207)
    trainImages = dataLoad.loadImages('train-images.idx3-ubyte')
    trainLabels = dataLoad.loadLabels('train-labels.idx1-ubyte')
    # testAnn reads these two names as module-level globals.
    testImages = dataLoad.loadImages('t10k-images.idx3-ubyte')
    testLabels = dataLoad.loadLabels('t10k-labels.idx1-ubyte')
    # Pair each image (as a (784, 1) column) with a (10, 1) one-hot target.
    dataSet = []
    for image, label in zip(trainImages, trainLabels):
        target = np.zeros((10, 1), dtype='float32')
        target[int(np.ravel(label)[0]), 0] = 1
        dataSet.append((image.reshape(784, 1), target))
    Wkh, Whi = bp(trainSet=dataSet, eta=0.05, nin=784, nhid=20, nout=10,
                  iterNum=20)
|
|