#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2016 Radim Rehurek <me@radimrehurek.com>
# Modifications (C) 2017 Hai Liang Wang <hailiang.hl.wang@gmail.com>
# Licensed under the GNU LGPL v3.0 - http://www.gnu.org/licenses/lgpl.html
# Author: Hai Liang Wang
# Date: 2017-10-16:14:13:24
#
#=========================================================================

from __future__ import print_function
from __future__ import division

__copyright__ = "Copyright (c) 2017 . All Rights Reserved"
__author__ = "Hai Liang Wang"
__date__ = "2017-10-16:14:13:24"

import os
import sys

curdir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(curdir)

if sys.version_info[0] < 3:
    reload(sys)
    sys.setdefaultencoding("utf-8")
    # raise "Must be using Python 3"
else:
    xrange = range

from absl import logging

import utils
from numpy import dot, zeros, dtype, float32 as REAL,\
    double, array, vstack, frombuffer, sqrt, newaxis,\
    ndarray, sum as np_sum, prod, ascontiguousarray,\
    argmax
from sklearn.neighbors import KDTree
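
# NOTE: `utils` above is a local helper module (not shown here). From the call
# sites below it must provide at least `smart_open`, `to_unicode` and `cosine`.
# As an illustration only, the cosine-similarity helper this file relies on can
# be sketched with the standard formula (an assumption, not the actual code):
#
#     def cosine(a, b):
#         return dot(a, b) / (sqrt(np_sum(a ** 2)) * sqrt(np_sum(b ** 2)))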


class Vocab(object):
    """
    A single vocabulary item, used internally for collecting per-word frequency/sampling info,
    and for constructing binary trees (incl. both word leaves and inner nodes).
    """

    def __init__(self, **kwargs):
        self.count = 0
        self.__dict__.update(kwargs)

    def __lt__(self, other):  # used for sorting in a priority queue
        return self.count < other.count

    def __str__(self):
        vals = ['%s:%r' % (key, self.__dict__[key])
                for key in sorted(self.__dict__) if not key.startswith('_')]
        return "%s(%s)" % (self.__class__.__name__, ', '.join(vals))
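
# Illustrative only: because Vocab compares by `count`, a plain sort (or a
# priority queue) orders items by frequency, e.g.:
#
#     >>> print(sorted([Vocab(count=5), Vocab(count=2)])[0])
#     Vocab(count:2)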


class KeyedVectors():
    """
    Class to contain vectors and vocab for the Word2Vec training class and other
    w2v methods not directly involved in training, such as most_similar().
    """

    def __init__(self):
        self.syn0 = []           # raw word vectors, one row per word
        self.syn0norm = None     # normalized copy of syn0; not populated by this class itself
        self.vocab = {}          # word -> Vocab(index, count)
        self.index2word = []     # row index -> word
        self.vector_size = None
        self.kdt = None          # KDTree over syn0, built in load_word2vec_format()

    @property
    def wv(self):
        return self

    def save(self, *args, **kwargs):
        # don't bother storing the cached normalized vectors
        kwargs['ignore'] = kwargs.get('ignore', ['syn0norm'])
        # NOTE: this assumes a base class that provides save() (gensim's
        # KeyedVectors inherits utils.SaveLoad); without such a base, this
        # super() call raises AttributeError.
        super(KeyedVectors, self).save(*args, **kwargs)

    @classmethod
    def load_word2vec_format(
            cls,
            fname,
            fvocab=None,
            binary=False,
            encoding='utf8',
            unicode_errors='strict',
            limit=None,
            datatype=REAL):
        """
        Load the input-hidden weight matrix from the original C word2vec-tool format.

        Note that the information stored in the file is incomplete (the binary tree is missing),
        so while you can query for word similarity etc., you cannot continue training
        with a model loaded this way.

        `binary` is a boolean indicating whether the data is in binary word2vec format.

        Word counts are read from the `fvocab` filename, if set (this is the file
        generated by the `-save-vocab` flag of the original C tool).

        If you trained the C model using a non-utf8 encoding for words, specify that
        encoding in `encoding`.

        `unicode_errors`, default 'strict', is a string suitable to be passed as the `errors`
        argument to the unicode() (Python 2.x) or str() (Python 3.x) function. If your source
        file may include word tokens truncated in the middle of a multibyte unicode character
        (as is common from the original word2vec.c tool), 'ignore' or 'replace' may help.

        `limit` sets a maximum number of word-vectors to read from the file. The default,
        None, means read all.

        `datatype` (experimental) can coerce dimensions to a non-default float type (such
        as np.float16) to save memory. (Such types may result in much slower bulk operations
        or incompatibility with optimized routines.)
        """
        counts = None
        if fvocab is not None:
            logging.debug("loading word counts from %s" % fvocab)
            counts = {}
            with utils.smart_open(fvocab) as fin:
                for line in fin:
                    word, count = utils.to_unicode(line).strip().split()
                    counts[word] = int(count)

        logging.debug("loading projection weights from %s" % fname)
        with utils.smart_open(fname) as fin:
            header = utils.to_unicode(fin.readline(), encoding=encoding)
            # throws for invalid file format
            vocab_size, vector_size = (int(x) for x in header.split())
            if limit:
                vocab_size = min(vocab_size, limit)
            result = cls()
            result.vector_size = vector_size
            result.syn0 = zeros((vocab_size, vector_size), dtype=datatype)
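
            # The header just parsed is "<vocab_size> <vector_size>". In text
            # mode, each following line is "<word> <f_1> ... <f_vector_size>";
            # in binary mode, each entry is "<word> " followed by vector_size
            # raw float32 values (binary_len bytes below).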

            def add_word(word, weights):
                word_id = len(result.vocab)
                # logging.debug("word id: %d, word: %s, weights: %s" % (word_id, word, weights))
                if word in result.vocab:
                    logging.debug(
                        "duplicate word '%s' in %s, ignoring all but first" %
                        (word, fname))
                    return
                if counts is None:
                    # most common scenario: no vocab file given. just make up
                    # some bogus counts, in descending order
                    result.vocab[word] = Vocab(
                        index=word_id, count=vocab_size - word_id)
                elif word in counts:
                    # use count from the vocab file
                    result.vocab[word] = Vocab(
                        index=word_id, count=counts[word])
                else:
                    # vocab file given, but word is missing -- set count to
                    # None (TODO: or raise?)
                    logging.debug(
                        "vocabulary file is incomplete: '%s' is missing" %
                        word)
                    result.vocab[word] = Vocab(index=word_id, count=None)
                result.syn0[word_id] = weights
                result.index2word.append(word)

            if binary:
                binary_len = dtype(REAL).itemsize * vector_size
                for _ in xrange(vocab_size):
                    # mixed text and binary: read text first, then binary
                    word = []
                    while True:
                        ch = fin.read(1)
                        if ch == b' ':
                            break
                        if ch == b'':
                            raise EOFError(
                                "unexpected end of input; is count incorrect or file otherwise damaged?")
                        # ignore newlines in front of words (some binary files have)
                        if ch != b'\n':
                            word.append(ch)
                    word = utils.to_unicode(
                        b''.join(word), encoding=encoding, errors=unicode_errors)
                    # numpy.fromstring is deprecated for raw bytes, so use the
                    # equivalent frombuffer
                    weights = frombuffer(fin.read(binary_len), dtype=REAL)
                    add_word(word, weights)
            else:
                for line_no in xrange(vocab_size):
                    line = fin.readline()
                    if line == b'':
                        raise EOFError(
                            "unexpected end of input; is count incorrect or file otherwise damaged?")
                    parts = utils.to_unicode(
                        line.rstrip(),
                        encoding=encoding,
                        errors=unicode_errors).split(" ")
                    if len(parts) != vector_size + 1:
                        raise ValueError(
                            "invalid vector on line %s (is this really the text format?)" %
                            line_no)
                    word, weights = parts[0], [REAL(x) for x in parts[1:]]
                    add_word(word, weights)

        if result.syn0.shape[0] != len(result.vocab):
            logging.debug(
                "duplicate words detected, shrinking matrix size from %i to %i" %
                (result.syn0.shape[0], len(result.vocab)))
            result.syn0 = ascontiguousarray(result.syn0[: len(result.vocab)])
        assert (len(result.vocab), vector_size) == result.syn0.shape

        # Build a KDTree over the vectors for nearest-neighbour lookup:
        # http://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KDTree.html#sklearn.neighbors.KDTree
        result.kdt = KDTree(result.syn0, leaf_size=10, metric="euclidean")
        logging.debug("loaded %s matrix from %s" % (result.syn0.shape, fname))
        return result
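
    # Usage sketch (illustrative; the file name is hypothetical):
    #
    #     kv = KeyedVectors.load_word2vec_format('vectors.bin', binary=True,
    #                                            unicode_errors='ignore')
    #     kv.syn0.shape  # -> (vocab_size, vector_size)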

    def word_vec(self, word, use_norm=False):
        """
        Accept a single word as input.
        Returns the word's representation in vector space, as a 1D numpy array.

        If `use_norm` is True, returns the normalized word vector. Note that
        `syn0norm` is never filled in by this class itself, so it must be set
        externally before `use_norm=True` can be used.

        Example::

            >>> trained_model.word_vec('office')
            array([ -1.40128313e-02, ...])
        """
        if word in self.vocab:
            if use_norm:
                result = self.syn0norm[self.vocab[word].index]
            else:
                result = self.syn0[self.vocab[word].index]

            result.setflags(write=False)
            return result
        else:
            raise KeyError("word '%s' not in vocabulary" % word)

    def neighbours(self, word, size=10):
        """
        Get the nearest words via the KDTree, ranked by cosine similarity.
        """
        word = word.strip()
        v = self.word_vec(word)
        # the KDTree retrieves candidates by euclidean distance; the returned
        # distances are only sanity-checked, and scores are recomputed as
        # cosine similarity against the query vector
        [distances], [points] = self.kdt.query(array([v]), k=size, return_distance=True)
        assert len(distances) == len(points), "distances and points should be in same shape."
        words, scores = [], {}
        for (x, y) in zip(points, distances):
            w = self.index2word[x]
            if w == word:
                s = 1.0
            else:
                s = utils.cosine(v, self.syn0[x])
            if s < 0:
                s = abs(s)
            words.append(w)
            scores[w] = min(s, 1.0)
        for x in sorted(words, key=scores.get, reverse=True):
            yield x, scores[x]
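

# Usage sketch for KeyedVectors.neighbours() (illustrative; the model file and
# query word are hypothetical):
#
#     kv = KeyedVectors.load_word2vec_format('words.vector', binary=True,
#                                            unicode_errors='ignore')
#     for w, score in kv.neighbours('word', size=5):
#         print(w, score)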

import unittest

# run testcase: python /Users/hain/tmp/ss Test.testExample


class Test(unittest.TestCase):
    '''
    Unit tests for KeyedVectors.
    '''

    def setUp(self):
        pass

    def tearDown(self):
        pass

    def test_load_w2v_data(self):
        _fin_wv_path = os.path.join(curdir, 'data', 'words.vector')
        _fin_stopwords_path = os.path.join(curdir, 'data', 'stopwords.txt')
        binary = True
        # load_word2vec_format is a classmethod returning a new instance, so
        # bind its result rather than calling it on a throwaway object
        kv = KeyedVectors.load_word2vec_format(
            _fin_wv_path,
            binary=binary,
            unicode_errors='ignore')
        self.assertEqual(len(kv.vocab), kv.syn0.shape[0])


def test():
    unittest.main()


if __name__ == '__main__':
    test()