设为首页 收藏本站
查看: 1366|回复: 0

[经验分享] BP算法 python实现

[复制链接]
累计签到:1 天
连续签到:1 天
发表于 2015-12-1 11:26:45 | 显示全部楼层 |阅读模式
  0-9数字识别,NMIST数据的识别。
  具体代码包括NMIST见附件中。
  参考资料是TOM的机器学习BP那一章。
  


DSC0000.gif DSC0001.gif


# coding:utf-8
# 没考虑大小端
import struct
import numpy
def loadImages(filename):
try:
f = open(filename,'rb')
except Exception as instance:
print type(instance)
exit()
allImage = []
bins = f.read()
index = 0
magicNum,imageNum,rowNum,colNum = struct.unpack_from('>IIII',bins,index)
index = index + struct.calcsize('>IIII')
assert 2051 == magicNum,'dataset damaged | little endian'
for ct in xrange(imageNum):
allImage.append(struct.unpack_from('>784B',bins,index))
index = index + struct.calcsize('>784B')
return numpy.array(allImage,dtype='float32')
def loadLabels(filename):
try:
f = open(filename,'rb')
except Exception as instance:
print type(instance)
exit()
allLabels = []
bins = f.read()
index = 0
magicNum,labelNum = struct.unpack_from('>II',bins,index)
index = index + struct.calcsize('>II')
assert 2049 == magicNum,'dataset damaged | little endian'
for ct in xrange(labelNum):
allLabels.append(struct.unpack_from('B',bins,index))
index = index + struct.calcsize('B')
return numpy.array(allLabels,dtype='float32')


if '__main__' == __name__:
images = loadImages('t10k-images.idx3-ubyte')
labels = loadLabels('t10k-labels.idx1-ubyte')
import matplotlib.pyplot as plt
for x in range(3):
plt.figure()
shown = images[x].reshape(28,28)
# shown 28*28 numpy matrix
plt.imshow(shown,cmap='gray')
plt.title(str(labels[x]))
plt.show()





读取NMIST


# -*- coding: utf-8 -*-
import dataLoad
import numpy as np
import sys
import warnings
def bp(trainSet,eta=0.01,nin=None,nhid=None,nout=None,iterNum = 10):
'''[(instance,label),((784,1)array,(10,1)array)……]'''
Wkh = (np.random.rand(nout,nhid)-0.5) / 10.0
Whi = (np.random.rand(nhid,nin )-0.5) / 10.0
iteration = 0
er = 1
'''iteration'''
while iteration < iterNum and er > 0.04:
print 'iteration=',iteration
er = testAnn((Whi,Wkh))
iteration += 1
for (x,label) in trainSet:
# 最大最小归一化
x = (x - x.min())/x.max()-x.min()
#前向
neth = np.dot(Whi,x) # nhid*nin nin*1 -> nhid*1
oh = sigmoid(neth)   # nhid*1
netk = np.dot(Wkh,oh) #nout*nhid nhid*1 -> nout*1
ok = sigmoid(netk) # nout*1
#求误差
dk = ok*(1-ok)*(label-ok)
dh = oh*(1-oh)*np.dot(Wkh.T,dk) #(nhid,1) = (nout,nhid).T * (nout,1)
#更新权值矩阵
Wkh = Wkh + eta * np.dot(dk,oh.T)  #nout*nhid + nout*1*1*hid
Whi = Whi + eta * np.dot(dh,x.T)  #nhid*nin + nhid*1*1*nin
print 'iteration over'
return Wkh,Whi
def testAnn(model):
err = 0   
for i in range(len(testLabels)):
res = fit(model,testImages.reshape(784,1))
if i > 9990:
print '\t ',int(testLabels[0]),'was recognized as',res[1]
if testLabels[0] != res[1]:
err += 1
errorRate = float(err)/float(len(testLabels))
print 'error rate',errorRate,'\n'
return errorRate
def fit(model,Image):
Whi,Wkh = model
ok = list(sigmoid(sigmoid(Image.T.dot(Whi.T)).dot(Wkh.T))[0])
return ok,ok.index(max(ok))
def sigmoid(y):
# 讨厌的溢出警告
warnings.filterwarnings("ignore")
return 1/(1+np.exp(-y))   
if '__main__' == __name__:
np.random.seed(207)
trainImages = dataLoad.loadImages('train-images.idx3-ubyte')
trainLabels = dataLoad.loadLabels('train-labels.idx1-ubyte')
testImages = dataLoad.loadImages('t10k-images.idx3-ubyte')
testLabels = dataLoad.loadLabels('t10k-labels.idx1-ubyte')
dataSet = []
for i in range(len(trainLabels)):
tmp = np.zeros((10,1),dtype='float32')
tmp[int(trainLabels),0] = 1
dataSet.append((trainImages.reshape(784,1),tmp))
bp(trainSet=dataSet,eta=0.05,nin=784,nhid=20,nout=10,iterNum=20)

  

运维网声明 1、欢迎大家加入本站运维交流群:群②:261659950 群⑤:202807635 群⑦870801961 群⑧679858003
2、本站所有主题由该帖子作者发表,该帖子作者与运维网享有帖子相关版权
3、所有作品的著作权均归原作者享有,请您和我们一样尊重他人的著作权等合法权益。如果您对作品感到满意,请购买正版
4、禁止制作、复制、发布和传播具有反动、淫秽、色情、暴力、凶杀等内容的信息,一经发现立即删除。若您因此触犯法律,一切后果自负,我们对此不承担任何责任
5、所有资源均系网友上传或者通过网络收集,我们仅提供一个展示、介绍、观摩学习的平台,我们不对其内容的准确性、可靠性、正当性、安全性、合法性等负责,亦不承担任何法律责任
6、所有作品仅供您个人学习、研究或欣赏,不得用于商业或者其他用途,否则,一切后果均由您自己承担,我们对此不承担任何法律责任
7、如涉及侵犯版权等问题,请您及时通知我们,我们将立即采取措施予以解决
8、联系人Email:admin@iyunv.com 网址:www.yunweiku.com

所有资源均系网友上传或者通过网络收集,我们仅提供一个展示、介绍、观摩学习的平台,我们不对其承担任何法律责任,如涉及侵犯版权等问题,请您及时通知我们,我们将立即处理,联系人Email:kefu@iyunv.com,QQ:1061981298 本贴地址:https://www.yunweiku.com/thread-145811-1-1.html 上篇帖子: pyDes库 实现python的des加密 下篇帖子: Python + OpenCV2 系列:1
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

扫码加入运维网微信交流群X

扫码加入运维网微信交流群

扫描二维码加入运维网微信交流群,最新一手资源尽在官方微信交流群!快快加入我们吧...

扫描微信二维码查看详情

客服E-mail:kefu@iyunv.com 客服QQ:1061981298


QQ群⑦:运维网交流群⑦ QQ群⑧:运维网交流群⑧ k8s群:运维网kubernetes交流群


提醒:禁止发布任何违反国家法律、法规的言论与图片等内容;本站内容均来自个人观点与网络等信息,非本站认同之观点.


本站大部分资源是网友从网上搜集分享而来,其版权均归原作者及其网站所有,我们尊重他人的合法权益,如有内容侵犯您的合法权益,请及时与我们联系进行核实删除!



合作伙伴: 青云cloud

快速回复 返回顶部 返回列表