yanglgzh 发表于 2015-11-30 11:27:12

Python实现ORM

  ORM即把数据库中的一个数据表给映射到代码里的一个类上,表的字段对应着类的属性。将增删改查等基本操作封装为类对应的方法,从而写出更干净和更富有层次性的代码。
  以查询数据为例,原始的写法要Python代码sql混合,示例代码如下:





1 import MySQLdb
2 import os,sys
3
4 def main():
5   conn=MySQLdb.connect(host="localhost",port=3306,passwd='toor',user='root')
6   conn.select_db("xdyweb")
7   cursor=conn.cursor()
8   count=cursor.execute("select * from users")
9   result=cursor.fetchmany()
10   print(isinstance(result,tuple))
11   print(type(result))
12   print(len(result))
13   for i in result:
14         print(i)
15         for j in i:
16             print(j)
17   print("row count is %s"%count)
18   cursor.close()
19   conn.close()
20
21 if __name__=="__main__":
22   cp=os.path.abspath('.')
23   sys.path.append(cp)
24   main()
View Code  而我们现在想要实现的是类似这样的效果:





1 #查找:
2 u=user.get(id=1)
3 #添加
4 u=user(name='y',password='y',email='1@q.com')
5 u.insert()
View Code  实现思路是遍历Model的属性,得出要操作的字段,然后根据不同的操作要求(增,删,改,查)去动态生成不同的sql语句。






1 #coding:utf-8
2
3 #author:xudongyang
4
5 #19:25 2015/4/15
6
7 importlogging,time,sys,os,threading
8 import test as db
9 # logging.basicConfig(level=logging.INFO,format='%(asctime)s %(filename)s %(levelname)s %(message)s',datefmt='%a, %d %b %Y %H:%M:%S')
10 logging.basicConfig(level=logging.INFO)
11
12 class Field(object):
13   #映射数据表中一个字段的属性,包括字段名称,默认值,是否主键,可空,可更新,可插入,字段类型(varchar,text,Integer之类),字段顺序
14   _count=0#当前定义的字段是类的第几个字段
15   def __init__(self,**kw):
16         self.name = kw.get('name', None)
17         self._default = kw.get('default', None)
18         self.primary_key = kw.get('primary_key', False)
19         self.nullable = kw.get('nullable', False)
20         self.updatable = kw.get('updatable', True)
21         self.insertable = kw.get('insertable', True)
22         self.ddl = kw.get('ddl', '')
23         self._order = Field._count
24         Field._count = Field._count + 1
25   @property
26   def default(self):
27         d = self._default
28         return d() if callable(d) else d
29
30 class StringField(Field):
31   #继承自Field,
32   def __init__(self, **kw):
33         if not 'default' in kw:
34             kw['default'] = ''
35         if not 'ddl' in kw:
36             kw['ddl'] = 'varchar(255)'
37         super(StringField, self).__init__(**kw)
38
39 class IntegerField(Field):
40
41   def __init__(self, **kw):
42         if not 'default' in kw:
43             kw['default'] = 0
44         if not 'ddl' in kw:
45             kw['ddl'] = 'bigint'
46         super(IntegerField, self).__init__(**kw)
47 class FloatField(Field):
48
49   def __init__(self, **kw):
50         if not 'default' in kw:
51             kw['default'] = 0.0
52         if not 'ddl' in kw:
53             kw['ddl'] = 'real'
54         super(FloatField, self).__init__(**kw)
55
56 class BooleanField(Field):
57
58   def __init__(self, **kw):
59         if not 'default' in kw:
60             kw['default'] = False
61         if not 'ddl' in kw:
62             kw['ddl'] = 'bool'
63         super(BooleanField, self).__init__(**kw)
64
65 class TextField(Field):
66
67   def __init__(self, **kw):
68         if not 'default' in kw:
69             kw['default'] = ''
70         if not 'ddl' in kw:
71             kw['ddl'] = 'text'
72         super(TextField, self).__init__(**kw)
73
74 class BlobField(Field):
75
76   def __init__(self, **kw):
77         if not 'default' in kw:
78             kw['default'] = ''
79         if not 'ddl' in kw:
80             kw['ddl'] = 'blob'
81         super(BlobField, self).__init__(**kw)
82
83 class VersionField(Field):
84
85   def __init__(self, name=None):
86         super(VersionField, self).__init__(name=name, default=0, ddl='bigint')
87
88 def _gen_sql(table_name, mappings):
89   print(__name__+'is called'+str(time.time()))
90   pk = None
91   sql = ['-- generating SQL for %s:' % table_name, 'create table `%s` (' % table_name]
92   for f in sorted(mappings.values(), lambda x, y: cmp(x._order, y._order)):
93         if not hasattr(f, 'ddl'):
94             raise StandardError('no ddl in field "%s".' % n)
95         ddl = f.ddl
96         nullable = f.nullable
97         if f.primary_key:
98             pk = f.name
99         sql.append(nullable and '`%s` %s,' % (f.name, ddl) or '`%s` %s not null,' % (f.name, ddl))
100   sql.append('primary key(`%s`)' % pk)
101   sql.append(');')
102   sql='\n'.join(sql)
103   logging.info('sql is :'+sql)
104   return sql
105
106 class ModelMetaClass(type):
107   #为什么__new__方法会被调用两次
108   #为什么attrs.pop(k)要进行这个,而且进行了之后u.name就可以输出yy而不是一个Field对象
109   def __new__(cls,name,base,attrs):
110         logging.info("cls is:"+str(cls))
111         logging.info("name is:"+str(name))
112         logging.info("base is:"+str(base))
113         logging.info("attrs is:"+str(attrs))
114         print('new is called at '+str(cls)+str(time.time()))
115
116         if name =="Model":
117             return type.__new__(cls,name,base,attrs)
118         mapping=dict()
119         primary_key=None
120         for k,v in attrs.iteritems():
121             primary_key=None
122             if isinstance(v,Field):
123               if not v.name:
124                     v.name=k
125               mapping=v
126               #检测是否是主键
127               if v.primary_key:
128                     if primary_key:
129                         raise TypeError("There only should be on primary_key")
130                     if v.updatable:
131                         logging.warning('primary_key should not be changed')
132                         v.updatable=False
133                     if v.nullable:
134                         logging.warning('pri.. not be.null')
135                         v.nullable=False
136                     primary_key=v
137
138         for k in mapping.iterkeys():
139             attrs.pop(k)
140
141         attrs['__mappings__']=mapping
142         logging.info('mapping is :'+str(mapping))
143         attrs['__primary_key__']=primary_key
144         attrs['__sql__']=lambda self: _gen_sql(attrs['__table__'], mapping)
145         return type.__new__(cls,name,base,attrs)
146 class ModelMetaclass(type):
147   '''
148   Metaclass for model objects.
149   '''
150   def __new__(cls, name, bases, attrs):
151         # skip base Model class:
152         if name=='Model':
153             return type.__new__(cls, name, bases, attrs)
154
155         # store all subclasses info:
156         if not hasattr(cls, 'subclasses'):
157             cls.subclasses = {}
158         if not name in cls.subclasses:
159             cls.subclasses = name
160         else:
161             logging.warning('Redefine class: %s' % name)
162
163         logging.info('Scan ORMapping %s...' % name)
164         mappings = dict()
165         primary_key = None
166         for k, v in attrs.iteritems():
167             if isinstance(v, Field):
168               if not v.name:
169                     v.name = k
170               logging.info('Found mapping: %s => %s' % (k, v))
171               # check duplicate primary key:
172               if v.primary_key:
173                     if primary_key:
174                         raise TypeError('Cannot define more than 1 primary key in class: %s' % name)
175                     if v.updatable:
176                         logging.warning('NOTE: change primary key to non-updatable.')
177                         v.updatable = False
178                     if v.nullable:
179                         logging.warning('NOTE: change primary key to non-nullable.')
180                         v.nullable = False
181                     primary_key = v
182               mappings = v
183         # check exist of primary key:
184         if not primary_key:
185             raise TypeError('Primary key not defined in class: %s' % name)
186         for k in mappings.iterkeys():
187             attrs.pop(k)
188         if not '__table__' in attrs:
189             attrs['__table__'] = name.lower()
190         attrs['__mappings__'] = mappings
191         attrs['__primary_key__'] = primary_key
192         attrs['__sql__'] = lambda self: _gen_sql(attrs['__table__'], mappings)
193         # for trigger in _triggers:
194         #   if not trigger in attrs:
195         #         attrs = None
196         return type.__new__(cls, name, bases, attrs)
197 class Model(dict):
198   __metaclass__ = ModelMetaClass
199   def __init__(self, **kw):
200         super(Model, self).__init__(**kw)
201
202   def __getattr__(self, key):
203         try:
204             return self
205         except KeyError:
206             raise AttributeError(r"'Dict' object has no attribute '%s'" % key)
207
208   def __setattr__(self, key, value):
209         self = value
210
211   @classmethod
212   def get(cls, pk):
213         '''
214         Get by primary key.
215         '''
216         d = db.select_one('select * from %s where %s=?' % (cls.__table__, cls.__primary_key__.name), pk)
217         return cls(**d) if d else None
218
219   @classmethod
220   def find_first(cls, where, *args):
221         '''
222         Find by where clause and return one result. If multiple results found,
223         only the first one returned. If no result found, return None.
224         '''
225         d = db.select_one('select * from %s %s' % (cls.__table__, where), *args)
226         return cls(**d) if d else None
227
228   @classmethod
229   def find_all(cls, *args):
230         '''
231         Find all and return list.
232         '''
233         L = db.select('select * from `%s`' % cls.__table__)
234         return
235
236   @classmethod
237   def find_by(cls, where, *args):
238         '''
239         Find by where clause and return list.
240         '''
241         L = db.select('select * from `%s` %s' % (cls.__table__, where), *args)
242         return
243
244   @classmethod
245   def count_all(cls):
246         '''
247         Find by 'select count(pk) from table' and return integer.
248         '''
249         return db.select_int('select count(`%s`) from `%s`' % (cls.__primary_key__.name, cls.__table__))
250
251   @classmethod
252   def count_by(cls, where, *args):
253         '''
254         Find by 'select count(pk) from table where ... ' and return int.
255         '''
256         return db.select_int('select count(`%s`) from `%s` %s' % (cls.__primary_key__.name, cls.__table__, where), *args)
257
258   def update(self):
259         self.pre_update and self.pre_update()
260         L = []
261         args = []
262         for k, v in self.__mappings__.iteritems():
263             if v.updatable:
264               if hasattr(self, k):
265                     arg = getattr(self, k)
266               else:
267                     arg = v.default
268                     setattr(self, k, arg)
269               L.append('`%s`=?' % k)
270               args.append(arg)
271         pk = self.__primary_key__.name
272         args.append(getattr(self, pk))
273         db.update('update `%s` set %s where %s=?' % (self.__table__, ','.join(L), pk), *args)
274         return self
275
276   def delete(self):
277         self.pre_delete and self.pre_delete()
278         pk = self.__primary_key__.name
279         args = (getattr(self, pk), )
280         db.update('delete from `%s` where `%s`=?' % (self.__table__, pk), *args)
281         return self
282
283   def insert(self):
284         self.pre_insert and self.pre_insert()
285         params = {}
286         for k, v in self.__mappings__.iteritems():
287             if v.insertable:
288               if not hasattr(self, k):
289                     setattr(self, k, v.default)
290               params = getattr(self, k)
291         db.insert('%s' % self.__table__, **params)
292         return self
293 class user(Model):
294   name=StringField(name='name',primary_key=True)
295   password=StringField(name='password')
296
297 def main():
298   u=user(name='yy',password='yyp')
299
300   logging.info(u.__sql__)
301   logging.info(dir(u.__mappings__.values()))
302   u.password='xxx'
303   print(u.password)
304
305 if __name__ == '__main__':
306   main()
View Code  
  要注意的是遍历Model属性这部分代码,利用了Python的__metaclass__实现,截断了Model的创建过程,进而对Model的属性进行遍历,具体代码见ModelMetaclass的__new__方法实现。
  这是模仿廖老师的代码,,感谢。还有两个疑问注释在了代码中,希望有看明白的人解惑。
  
页: [1]
查看完整版本: Python实现ORM