path: root/debian/upstream/scripts/count_wikipedia.py
#!/usr/bin/python

import sys
import os
import re
import codecs
import operator
import datetime
import nltk
import warnings

from unidecode import unidecode

def usage():
    print '''
tokenize a directory of text and count unigrams.

usage:
%s input_dir ../data/english_wikipedia.txt

input_dir is the root directory where sentence files live. Each file should contain
one sentence per line, with punctuation. This script will walk the directory recursively,
looking for text files. For each text file, it will tokenize each sentence into words and
add them to a global unigram count, written to the given output file in the form:

word count
word count
...

in descending order of count.

For speed, tokenization is done w/ Penn Treebank regexes via nltk's port:
http://www.cis.upenn.edu/~treebank/tokenizer.sed
http://www.nltk.org/api/nltk.tokenize.html#module-nltk.tokenize.treebank

For input sentences, this script accepts the format output by WikiExtractor.py
https://github.com/attardi/wikiextractor

That is,
- lines starting with <doc... are ignored
- lines starting with </doc> are ignored
- blank lines are ignored

To obtain wikipedia dumps, visit https://dumps.wikimedia.org/enwiki and download the file
ending in '-pages-articles.xml.bz2'. This includes wikipedia pages and articles but not
previous revisions, edit history, or metadata.

Then run:
./WikiExtractor.py -o en_sents --no-templates enwiki-20151002-pages-articles.xml.bz2

''' % sys.argv[0]

SENTENCES_PER_BATCH = 500000 # after each batch, delete all counts with count == 1 (hapax legomena)
PRE_SORT_CUTOFF = 300        # before sorting, discard all words with fewer than this many occurrences

ALL_NON_ALPHA = re.compile(r'^[\W\d]*$', re.UNICODE)
SOME_NON_ALPHA = re.compile(r'[\W\d]', re.UNICODE)

class TopTokenCounter(object):
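    # Tracks unigram counts across all input. `legomena` holds tokens seen
    # exactly once since the last prune; `discarded` holds tokens rejected
    # by should_include().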
    def __init__(self):
        self.count = {}
        self.legomena = set()
        self.discarded = set()

    def add_tokens(self, tokens, split_hyphens=True):
        for token in tokens:
            # add e.g. 'marxist-leninist' as separate tokens instead of one
            if split_hyphens and token.count('-') in [1, 2]:
                for subtoken in token.split('-'):
                    self.add_token(subtoken)
            else:
                self.add_token(token)

    def add_token(self, token):
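        # normalize and count a single token; track first-time tokens in
        # self.legomena so batch_prune() can drop them later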
        if not self.should_include(token):
            self.discarded.add(token)
            return
        token = self.normalize(token)
        if token in self.count:
            self.legomena.discard(token)
            self.count[token] += 1
        else:
            self.legomena.add(token)
            self.count[token] = 1

    def should_include(self, token):
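        # reject tokens unlikely to be real words: single characters, short
        # tokens containing punctuation/digits, purely non-alphabetic strings,
        # and URL/markup fragments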
        if len(token) < 2:
            return False
        if len(token) <= 2 and SOME_NON_ALPHA.search(token):
            # B., '', (), ...
            return False
        if ALL_NON_ALPHA.match(token):
            # 1,000, <<>>, ...
            return False
        if token.startswith('/'):
            # eg //en.wikipedia.org/wiki, /doc
            return False
        if token.endswith('='):
            # id=, title=, ...
            return False
        return True

    def normalize(self, token):
        return token.lower()

    def batch_prune(self):
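        # drop every token seen only once so far (hapax legomena) to keep
        # memory bounded between batches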
        for token in self.legomena:
            del self.count[token]
        self.legomena = set()

    def pre_sort_prune(self):
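        # before the final sort, drop every token whose count is below
        # PRE_SORT_CUTOFF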
        under_cutoff = set()
        for token, count in self.count.iteritems():
            if count < PRE_SORT_CUTOFF:
                under_cutoff.add(token)
        for token in under_cutoff:
            del self.count[token]
        self.legomena = set()

    def get_sorted_pairs(self):
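        # (token, count) pairs in descending order of count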
        return sorted(self.count.items(), key=operator.itemgetter(1), reverse=True)

    def get_ts(self):
        return datetime.datetime.now().strftime("%b %d %Y %H:%M:%S")

    def get_stats(self):
        ts = self.get_ts()
        return "%s keys(count): %d" % (ts, len(self.count))

def main(input_dir_str, output_filename):
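    # walk input_dir recursively, tokenize every line, periodically prune
    # hapax legomena, then discard low-count tokens, sort, and write
    # 'word count' pairs to output_filename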
    counter = TopTokenCounter()
    print counter.get_ts(), 'starting...'
    lines = 0
    for root, dirs, files in os.walk(input_dir_str, topdown=True):
        if not files:
            continue
        for fname in files:
            path = os.path.join(root, fname)
            for line in codecs.open(path, 'r', 'utf8'):
                with warnings.catch_warnings():
                    # unidecode() occasionally (rarely, but enough to clog terminal output)
                    # complains about surrogate characters in some wikipedia sentences.
                    # ignore those warnings.
                    warnings.simplefilter('ignore')
                    line = unidecode(line)
                tokens = nltk.word_tokenize(line)
                counter.add_tokens(tokens)
                lines += 1
                if lines % SENTENCES_PER_BATCH == 0:
                    counter.batch_prune()
                    print counter.get_stats()
                    print 'processing: %s' % path
    print counter.get_stats()
    print 'deleting tokens under cutoff of', PRE_SORT_CUTOFF
    counter.pre_sort_prune()
    print 'done'
    print counter.get_stats()
    print counter.get_ts(), 'sorting...'
    sorted_pairs = counter.get_sorted_pairs()
    print counter.get_ts(), 'done'
    print 'writing...'
    with codecs.open(output_filename, 'w', 'utf8') as f:
        for token, count in sorted_pairs:
            f.write('%-18s %d\n' % (token, count))
    sys.exit(0)

if __name__ == '__main__':
    if len(sys.argv) != 3:
        usage()
        sys.exit(0)
    else:
        main(*sys.argv[1:])