设为首页 收藏本站
查看: 857|回复: 0

[经验分享] Boltzmann机神经网络python实现

[复制链接]

尚未签到

发表于 2015-12-1 12:32:11 | 显示全部楼层 |阅读模式
  (python 3)
  



  1 import numpy
  2 from scipy import sparse as S
  3 from matplotlib import pyplot as plt
  4 from scipy.sparse.csr import csr_matrix
  5 import pandas
  6
  7 def normalize(x):
  8     V = x.copy()
  9     V -= x.min(axis=1).reshape(x.shape[0],1)
10     V /= V.max(axis=1).reshape(x.shape[0],1)
11     return V
12     
13 def sigmoid(x):
14     #return x*(x > 0)
15     #return numpy.tanh(x)
16     return 1.0/(1+numpy.exp(-x))
17
18 class RBM():
19     def __init__(self, n_visible=None, n_hidden=None, W=None, learning_rate = 0.1, weight_decay=1,cd_steps=1,momentum=0.5):
20         if W == None:
21             self.W =  numpy.random.uniform(-.1,0.1,(n_visible,  n_hidden)) / numpy.sqrt(n_visible + n_hidden)
22             self.W = numpy.insert(self.W, 0, 0, axis = 1)
23             self.W = numpy.insert(self.W, 0, 0, axis = 0)
24         else:
25             self.W=W
26         self.learning_rate = learning_rate
27         self.momentum = momentum
28         self.last_change = 0
29         self.last_update = 0
30         self.cd_steps = cd_steps
31         self.epoch = 0
32         self.weight_decay = weight_decay  
33         self.Errors = []
34         
35            
36     def fit(self, Input, max_epochs = 1, batch_size=100):  
37         if isinstance(Input, S.csr_matrix):
38             bias = S.csr_matrix(numpy.ones((Input.shape[0], 1)))
39             csr = S.hstack([bias, Input]).tocsr()
40         else:
41             csr = numpy.insert(Input, 0, 1, 1)
42         for epoch in range(max_epochs):
43             idx = numpy.arange(csr.shape[0])
44             numpy.random.shuffle(idx)
45             idx = idx[:batch_size]  
46                    
47             self.V_state = csr[idx]
48             self.H_state = self.activate(self.V_state)
49             pos_associations = self.V_state.T.dot(self.H_state)
50   
51             for i in range(self.cd_steps):
52               self.V_state = self.sample(self.H_state)  
53               self.H_state = self.activate(self.V_state)
54               
55             neg_associations = self.V_state.T.dot(self.H_state)
56             self.V_state = self.sample(self.H_state)
57            
58             # Update weights.
59             w_update = self.learning_rate * ((pos_associations - neg_associations) / batch_size)
60             total_change = numpy.sum(numpy.abs(w_update))
61             self.W += self.momentum * self.last_change  + w_update
62             self.W *= self.weight_decay
63            
64             self.last_change = w_update
65            
66             RMSE = numpy.mean((csr[idx] - self.V_state)**2)**0.5
67             self.Errors.append(RMSE)
68             self.epoch += 1
69             print("Epoch %s: RMSE = %s; ||W||: %6.1f; Sum Update: %f" % (self.epoch, RMSE, numpy.sum(numpy.abs(self.W)), total_change))  
70         return self
71         
72     def learning_curve(self):
73         plt.ion()
74         #plt.figure()
75         plt.show()
76         E = numpy.array(self.Errors)
77         plt.plot(pandas.rolling_mean(E, 50)[50:])  
78      
79     def activate(self, X):
80         if X.shape[1] != self.W.shape[0]:
81             if isinstance(X, S.csr_matrix):
82                 bias = S.csr_matrix(numpy.ones((X.shape[0], 1)))
83                 csr = S.hstack([bias, X]).tocsr()
84             else:
85                 csr = numpy.insert(X, 0, 1, 1)
86         else:
87             csr = X
88         p = sigmoid(csr.dot(self.W))
89         p[:,0]  = 1.0
90         return p  
91         
92     def sample(self, H, addBias=True):
93         if H.shape[1] == self.W.shape[0]:
94             if isinstance(H, S.csr_matrix):
95                 bias = S.csr_matrix(numpy.ones((H.shape[0], 1)))
96                 csr = S.hstack([bias, H]).tocsr()
97             else:
98                 csr = numpy.insert(H, 0, 1, 1)
99         else:
100             csr = H
101         p = sigmoid(csr.dot(self.W.T))
102         p[:,0] = 1
103         return p
104      
105 if __name__=="__main__":
106     data = numpy.random.uniform(0,1,(100,10))
107     rbm = RBM(10,15)
108     rbm.fit(data,1000)
109     rbm.learning_curve()
  

运维网声明 1、欢迎大家加入本站运维交流群:群②:261659950 群⑤:202807635 群⑦870801961 群⑧679858003
2、本站所有主题由该帖子作者发表,该帖子作者与运维网享有帖子相关版权
3、所有作品的著作权均归原作者享有,请您和我们一样尊重他人的著作权等合法权益。如果您对作品感到满意,请购买正版
4、禁止制作、复制、发布和传播具有反动、淫秽、色情、暴力、凶杀等内容的信息,一经发现立即删除。若您因此触犯法律,一切后果自负,我们对此不承担任何责任
5、所有资源均系网友上传或者通过网络收集,我们仅提供一个展示、介绍、观摩学习的平台,我们不对其内容的准确性、可靠性、正当性、安全性、合法性等负责,亦不承担任何法律责任
6、所有作品仅供您个人学习、研究或欣赏,不得用于商业或者其他用途,否则,一切后果均由您自己承担,我们对此不承担任何法律责任
7、如涉及侵犯版权等问题,请您及时通知我们,我们将立即采取措施予以解决
8、联系人Email:admin@iyunv.com 网址:www.yunweiku.com

所有资源均系网友上传或者通过网络收集,我们仅提供一个展示、介绍、观摩学习的平台,我们不对其承担任何法律责任,如涉及侵犯版权等问题,请您及时通知我们,我们将立即处理,联系人Email:kefu@iyunv.com,QQ:1061981298 本贴地址:https://www.yunweiku.com/thread-145863-1-1.html 上篇帖子: 2015/11/1用Python写游戏,pygame入门(1):pygame的安装 下篇帖子: python学习:基础概念
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

扫码加入运维网微信交流群X

扫码加入运维网微信交流群

扫描二维码加入运维网微信交流群,最新一手资源尽在官方微信交流群!快快加入我们吧...

扫描微信二维码查看详情

客服E-mail:kefu@iyunv.com 客服QQ:1061981298


QQ群⑦:运维网交流群⑦ QQ群⑧:运维网交流群⑧ k8s群:运维网kubernetes交流群


提醒:禁止发布任何违反国家法律、法规的言论与图片等内容;本站内容均来自个人观点与网络等信息,非本站认同之观点.


本站大部分资源是网友从网上搜集分享而来,其版权均归原作者及其网站所有,我们尊重他人的合法权益,如有内容侵犯您的合法权益,请及时与我们联系进行核实删除!



合作伙伴: 青云cloud

快速回复 返回顶部 返回列表