|
1 # -*- coding: utf-8 -*-
'''
Naive Bayes classifier for bags of terms (add-one smoothing).

>>> c = Classy()
>>> c.train(['cpu', 'RAM', 'ALU', 'io', 'bridge', 'disk'], 'architecture')
True
>>> c.train(['monitor', 'mouse', 'keyboard', 'microphone', 'headphones'], 'input_devices')
True
>>> c.train(['desk', 'chair', 'cabinet', 'lamp'], 'office furniture')
True
>>> my_office = ['cpu', 'monitor', 'mouse', 'chair']
>>> label, score = c.classify(my_office)
>>> label
'input_devices'
'''
26
27 from collections import Counter
28 import math
29
class ClassifierNotTrainedException(Exception):
    """Raised when a Classy instance is used before any document was trained."""

    def __str__(self):
        # The message is fixed, whatever arguments the exception was built with.
        return "Classifier is not trained."
34
class Classy(object):
    """Multinomial naive Bayes classifier over bags of terms.

    Documents are iterables of hashable terms, each trained under a class id.
    ``classify`` scores every trained class by its log prior plus the
    add-one (Laplace) smoothed log likelihood of the document's terms and
    returns the best-scoring class.
    """

    def __init__(self):
        # class_id -> {term: number of occurrences in that class's documents}
        self.term_count_store = {}
        self.data = {
            # class_id -> total number of terms trained under that class
            'class_term_count': {},
            # class_id -> log P(class); refreshed by compute_beta_priors()
            'beta_priors': {},
            # class_id -> number of documents trained under that class
            'class_doc_count': {},
        }
        # Running totals over all training documents.
        self.total_term_count = 0
        self.total_doc_count = 0

    def train(self, document_source, class_id):
        """Add one labelled document to the model.

        Args:
            document_source: iterable of hashable terms. It is consumed only
                once, so one-shot iterators work as well as lists.
            class_id: hashable label of the document.

        Returns:
            True, once term counts, document counts and priors are updated.
        """
        term_counts = Counter(document_source)
        # Equals len(document_source) for sequences, but does not require
        # __len__, so generators are accepted too.
        document_length = sum(term_counts.values())

        class_terms = self.term_count_store.setdefault(class_id, {})
        for term, occurrences in term_counts.items():
            class_terms[term] = class_terms.get(term, 0) + occurrences

        class_term_count = self.data['class_term_count']
        class_term_count[class_id] = (
            class_term_count.get(class_id, 0) + document_length)
        class_doc_count = self.data['class_doc_count']
        class_doc_count[class_id] = class_doc_count.get(class_id, 0) + 1

        self.total_term_count += document_length
        self.total_doc_count += 1
        self.compute_beta_priors()
        return True

    def classify(self, document_input):
        """Return the most probable class for a document.

        Args:
            document_input: iterable of hashable terms (consumed once).

        Returns:
            Tuple ``(class_id, score)``. ``score`` is
            ``log P(class) + sum(freq(t) * log P(t | class))`` over the
            distinct terms ``t`` of the document that were seen in training,
            with ``P(t | class) = (count(t, class) + 1) / (terms(class) + |V|)``
            and ``|V|`` the number of distinct training terms. On ties the
            class encountered first while iterating the internal dicts wins
            (training order on Python 3.7+).

        Raises:
            ClassifierNotTrainedException: if nothing has been trained yet.
        """
        if not self.total_doc_count:
            raise ClassifierNotTrainedException()

        # |V| is the smoothing term of the denominator; it makes the smoothed
        # P(t | class) sum to 1 over the training vocabulary.
        vocabulary = set()
        for class_terms in self.term_count_store.values():
            vocabulary.update(class_terms)
        vocabulary_size = len(vocabulary)

        # Read the input once. Terms unseen in training are skipped instead
        # of aborting the sum, so the result does not depend on term order;
        # each distinct term is weighted by its frequency exactly once.
        observed = [(term, freq)
                    for term, freq in Counter(document_input).items()
                    if term in vocabulary]

        best_class, best_score = None, None
        for class_id in self.data['class_doc_count']:
            class_terms = self.term_count_store.get(class_id, {})
            log_likelihood = 0.0
            if observed:
                # observed non-empty implies |V| >= 1, so the argument is > 0.
                log_norm = math.log(
                    self.data['class_term_count'][class_id] + vocabulary_size)
                for term, freq in observed:
                    log_likelihood += freq * (
                        math.log(class_terms.get(term, 0) + 1.0) - log_norm)
            score = log_likelihood + self.data['beta_priors'][class_id]
            # Strict '>' keeps the first class encountered on ties.
            if best_score is None or score > best_score:
                best_class, best_score = class_id, score
        return (best_class, best_score)

    def compute_beta_priors(self):
        """Recompute log P(class) for every trained class.

        P(class) is the fraction of training documents labelled ``class``.

        Raises:
            ClassifierNotTrainedException: if nothing has been trained yet.
        """
        if not self.total_doc_count:
            raise ClassifierNotTrainedException()

        # float() avoids floor division (and thus log(0)) on Python 2.
        total = float(self.total_doc_count)
        for class_id, doc_count in self.data['class_doc_count'].items():
            self.data['beta_priors'][class_id] = math.log(doc_count / total)
|
|