diff --git a/question_classify.py b/question_classify.py index fa58ee7..0c1ac1f 100644 --- a/question_classify.py +++ b/question_classify.py @@ -112,8 +112,8 @@ class QuestionClassify(object): model.add(Conv1D(128, 3, activation='relu')) model.add(GlobalAveragePooling1D()) model.add(Dropout(0.5)) - model.add(Dense(13, activation='sigmoid')) - model.compile(loss='binary_crossentropy', + model.add(Dense(13, activation='softmax')) + model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy']) model.summary()