PyTorchText

1st Place Solution for Zhihu Machine Learning Challenge . Implementation of various text-classification models.(知乎看山杯第一名解决方案)

Github stars Tracking Chart

中文用户请查看readme-zh.md

This is the solution for Zhihu Machine Learning Challenge 2017. We won the champion out of 963 teams.

1. Setup

  • install PyTorch from pytorch.org (Python 2, CUDA)
  • install other depencies:
    pip2 install -r requirements.txt
    

You may need tf.contrib.keras.preprocessing.sequence.pad_sequences for data preprocessing.

  • start visdom for visualization:
    python2 -m visdom.server
    

2. Data Preprocessing

Modify the data path in the related file

2.1 wordvector file -> numpy file

python scripts/data_process/embedding2matrix.py main char_embedding.txt char_embedding.npz 
python scripts/data_process/embedding2matrix.py main word_embedding.txt word_embedding.npz 

2.2 question set -> numpy file

it's memory consuming , make sure you have memory larger than 32G.

python scripts/data_process/question2array.py main question_train_set.txt train.npz
python scripts/data_process/question2array.py main question_eval_set.txt test.npz

2.3 label -> json

python scripts/data_process/label2id.py main question_topic_train_set.txt labels.json

2.4 validation data

python scripts/data_process/get_val.py 

3. Training

modify config.py for model path

Path to the models we used:

  • CNN:models/MultiCNNTextBNDeep.py
  • RNN(LSTM):models/LSTMText.py
  • RCNN: models/RCNN.py
  • inception: models/CNNText_inception.py
  • FastText: models/FastText3.py

3.1 Trian model without data augumentation

# LSTM char
python2 main.py main --max_epoch=5 --plot_every=100 --env='lstm_char' --weight=1 --model='LSTMText'  --batch-size=128  --lr=0.001 --lr2=0 --lr_decay=0.5 --decay_every=10000  --type_='char'   --zhuge=True --linear-hidden-size=2000 --hidden-size=256 --kmax-pooling=3   --num-layers=3  --augument=False

# LSTM word
python2 main.py main --max_epoch=5 --plot_every=100 --env='lstm_word' --weight=1 --model='LSTMText'  --batch-size=128  --lr=0.001 --lr2=0.0000 --lr_decay=0.5 --decay_every=10000  --type_='word'   --zhuge=True --linear-hidden-size=2000 --hidden-size=320 --kmax-pooling=2  --augument=False

#  RCNN char
python2 main.py main --max_epoch=5 --plot_every=100 --env='rcnn_char' --weight=1 --model='RCNN'  --batch-size=128  --lr=0.001 --lr2=0 --lr_decay=0.5 --decay_every=5000  --title-dim=1024 --content-dim=1024  --type_='char' --zhuge=True --kernel-size=3 --kmax-pooling=2 --linear-hidden-size=2000 --debug-file='/tmp/debugrcnn' --hidden-size=256 --num-layers=3 --augument=False

# RCNN word
main.py main --max_epoch=5 --plot_every=100 --env='RCNN-word' --weight=1 --model='RCNN'  --zhuge=True --num-workers=4 --batch-size=128 --model-path=None --lr2=0 --lr=1e-3 --lr-decay=0.8  --decay-every=5000  --title-dim=1024 --content-dim=512  --kernel-size=3 --debug-file='/tmp/debugrc'  --kmax-pooling=1 --type_='word' --augument=False
# CNN word
 python main.py main --max_epoch=5 --plot_every=100 --env='MultiCNNText' --weight=1 --model='MultiCNNTextBNDeep'  --batch-size=64  --lr=0.001 --lr2=0.000 --lr_decay=0.8 --decay_every=10000  --title-dim=250 --content-dim=250    --weight-decay=0 --type_='word' --debug-file='/tmp/debug'  --linear-hidden-size=2000 --zhuge=True  --augument=False

# inception word
python2 main.py main --max_epoch=5 --plot_every=100 --env='inception-word' --weight=1 --model='CNNText_inception'  --zhuge=True --num-workers=4 --batch-size=512 --model-path=None --lr2=0 --lr=1e-3 --lr-decay=0.8  --decay-every=2500 --title-dim=1200 --content-dim=1200 --type_='word' --augument=False                                                   
# inception char
python2 main.py main --max_epoch=5 --plot_every=100 --env='inception-char' --weight=1 --model='CNNText_inception'  --zhuge=True --num-workers=4 --batch-size=512 --model-path=None --lr2=0 --lr=1e-3 --lr-decay=0.8  --decay-every=2500 --title-dim=1200 --content-dim=1200 --type_='char'   --augument=False

# FastText3 word
python2 main.py main --max_epoch=5 --plot_every=100 --env='fasttext3-word' --weight=5 --model='FastText3' --zhuge=True --num-workers=4 --batch-size=512  --lr2=1e-4 --lr=1e-3 --lr-decay=0.8  --decay-every=2500 --linear_hidden_size=2000 --type_='word'  --debug-file=/tmp/debugf --augument=False                           

In most cases, the score could be boosted by finetune. for example:

python2 main.py main --max_epoch=2 --plot_every=100 --env='LSTMText-word-ft' --model='LSTMText'  --zhuge=True --num-workers=4 --batch-size=256 --model-path=None --lr2=5e-5 --lr=5e-5 --decay-every=5000 --type_='word'  --model-path='checkpoints/LSTMText_word_0.409196378421'                       

3.2 train models with data augumentation

Add --augument in the training command.

3.3 scores, model, score, :---:, :----:

CNN_word, 0.4103
RNN_word, 0.4119
RCNN_word, 0.4115
Inceptin_word, 0.4109
FastText_word, 0.4091
RNN_char, 0.4031
RCNN_char, 0.4037
Inception_char, 0.4024
RCNN_word_aug, 0.41344
CNN_word_aug, 0.41051
RNN_word_aug, 0.41368
Incetpion_word_aug, 0.41254
FastText3_word_aug, 0.40853
CNN_char_aug, 0.38738
RCNN_char_aug, 0.39854

with model ensemble, it can get up to 0.433.

4 Test and Submit

4.1 Test

  • model: include LSTMText,RCNN,MultiCNNTextBNDeep,FastText3,CNNText_inception
  • model-path: path to the pretrained model
  • result-path: where to save the model
  • val: test the val set or the test set..
# LSTM
python2 test.1.py main --model='LSTMText'  --batch-size=512  --model-path='checkpoints/LSTMText_word_0.411994005382' --result-path='/data_ssd/zhihu/result/LSTMText0.4119_word_test.pth'  --val=False --zhuge=True

python2 test.1.py main --model='LSTMText'  --batch-size=256 --type_=char --model-path='checkpoints/LSTMText_char_0.403192339135' --result-path='/data_ssd/zhihu/result/LSTMText0.4031_char_test.pth'  --val=False --zhuge=True
 
#RCNN
python2 test.1.py main --model='RCNN'  --batch-size=512  --model-path='checkpoints/RCNN_word_0.411511574999' --result-path='/data_ssd/zhihu/result/RCNN_0.4115_word_test.pth'  --val=False --zhuge=True

python2 test.1.py main --model='RCNN'  --batch-size=512  --model-path='checkpoints/RCNN_char_0.403710422571' --result-path='/data_ssd/zhihu/result/RCNN_0.4037_char_test.pth'  --val=False --zhuge=True

# DeepText

python2 test.1.py main --model='MultiCNNTextBNDeep'  --batch-size=512  --model-path='checkpoints/MultiCNNTextBNDeep_word_0.410330780091' --result-path='/data_ssd/zhihu/result/DeepText0.4103_word_test.pth'  --val=False --zhuge=True
# more to go ...

4.2 ensemble

See notebooks/val_ensemble.ipynb and notebooks/test_ensemble.ipynb for more detail

5 Main files

  • main.py: main(for training)
  • config.py: config file
  • test.1.py: for test
  • data/: for data loader
  • scripts/: for data preprocessing
  • utils/ : including calculate score and wrapper for visualization.
  • models/: models
    • models/BasicModel: Base model for models.
    • models/MultiCNNTextBNDeep: CNN
    • models/LSTMText: RNN
    • models/RCNN: RCNN
    • models/CNNText_inception Inception
    • models/MultiModelALLmodels/MultiModelAll2
    • other model
  • rep.py: code for reproducing.
  • del/: methods fail or not used.
  • notebooks/: notebooks.

Pretrained model

https://pan.baidu.com/s/1mjVtJGs passwd: tayb

Main metrics

Overview
Name With Ownerdragen1860/TensorFlow-2.x-Tutorials
Primary LanguageJupyter Notebook
Program languagePython (Language Count: 2)
Platform
License:
所有者活动
Created At2019-01-30 03:00:09
Pushed At2020-09-23 15:50:30
Last Commit At2020-09-23 23:50:19
Release Count0
用户参与
Stargazers Count6.4k
Watchers Count227
Fork Count2.2k
Commits Count249
Has Issues Enabled
Issues Count38
Issue Open Count19
Pull Requests Count14
Pull Requests Open Count7
Pull Requests Close Count3
项目设置
Has Wiki Enabled
Is Archived
Is Fork
Is Locked
Is Mirror
Is Private