Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def path(self):
return self._path


class Model(object):
class NNModel(object):
""" Abstract class for Keras models.

Parameters:
Expand Down
18 changes: 10 additions & 8 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,19 @@

import numpy as np

from keras.models import Sequential
from keras.layers.core import Merge, Reshape, Dense
from keras.models import Sequential, Model
from keras.layers.core import Reshape, Dense
from keras.layers.merge import Dot
from keras.layers.embeddings import Embedding
from keras.preprocessing import sequence, text

from helper import CorpusReader, Model
from helper import CorpusReader, NNModel

FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=FORMAT, level=logging.DEBUG)


class WordContextModel(Model):
class WordContextModel(NNModel):
""" WordContextModel implements a word2vec-like model trained with
negative sampling. The word indices (of both the target word and
its context) are passed to two distinct networks and the goal is to
Expand Down Expand Up @@ -62,10 +63,11 @@ def _prepare_model(vocab_size, vector_dim, loss_function,
context.add(Reshape((vector_dim, 1)))

logging.info('Building composite graph')
model = Sequential()
model.add(Merge([word, context], mode='dot', dot_axes=1))
model.add(Reshape((1, )))
model.add(Dense(1, activation='sigmoid'))
dot_merged = Dot(axes=1)([word.output, context.output])
composite = Reshape((1, ))(dot_merged)
composite = Dense(1, activation='sigmoid')(composite)

model = Model([word.input, context.input], composite)
model.compile(loss=loss_function, optimizer=optimizer)
return model

Expand Down