Skip to content

Commit 8510560

Browse files
authored
update docs (#5180)
* update codes for senta benchmark
1 parent cb76056 commit 8510560

File tree

3 files changed

+6
-4
lines changed

3 files changed

+6
-4
lines changed

PaddleNLP/examples/text_classification/pretrained_models/predict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def predict(model, data, tokenizer, label_map, batch_size=1):
129129
examples = []
130130
for text in data:
131131
input_ids, segment_ids = convert_example(
132-
[text],
132+
text,
133133
tokenizer,
134134
label_list=label_map.values(),
135135
max_seq_length=args.max_seq_length,

PaddleNLP/examples/text_classification/rnn/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,13 +153,13 @@ wget https://paddlenlp.bj.bcebos.com/data/senta_word_dict.txt
153153
CPU 启动:
154154

155155
```shell
156-
python train.py --vocab_path='./senta_word_dict.txt' --use_gpu=False --network=bilstm --lr=5e-4 --batch_size=64 --epochs=5 --save_dir='./checkpoints'
156+
python train.py --vocab_path='./senta_word_dict.txt' --use_gpu=False --network=bilstm --lr=5e-4 --batch_size=64 --epochs=10 --save_dir='./checkpoints'
157157
```
158158

159159
GPU 启动:
160160

161161
```shell
162-
# CUDA_VISIBLE_DEVICES=0 python train.py --vocab_path='./senta_word_dict.txt' --use_gpu=True --network=bilstm --lr=5e-4 --batch_size=64 --epochs=5 --save_dir='./checkpoints'
162+
CUDA_VISIBLE_DEVICES=0 python train.py --vocab_path='./senta_word_dict.txt' --use_gpu=True --network=bilstm --lr=5e-4 --batch_size=64 --epochs=10 --save_dir='./checkpoints'
163163
```
164164

165165
以上参数表示:

PaddleNLP/examples/text_classification/rnn/train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,12 @@ def create_dataloader(dataset,
160160
print("Loaded checkpoint from %s" % args.init_from_ckpt)
161161

162162
# Starts training and evaluating.
163+
callback = paddle.callbacks.ProgBarLogger(log_freq=10, verbose=3)
163164
model.fit(train_loader,
164165
dev_loader,
165166
epochs=args.epochs,
166-
save_dir=args.save_dir)
167+
save_dir=args.save_dir,
168+
callbacks=callback)
167169

168170
# Finally tests model.
169171
results = model.evaluate(test_loader)

0 commit comments

Comments
 (0)