"""Train an image-classification network on the PaddlePaddle v2 flowers dataset.

Usage:
    python train.py {alexnet,vgg13,vgg16,vgg19,resnet,googlenet}
"""
import gzip
import paddle.v2.dataset.flowers as flowers
import paddle.v2 as paddle
import reader
import vgg
import resnet
import alexnet
import googlenet
import argparse

DATA_DIM = 3 * 224 * 224  # values per input image (presumably 3 x 224 x 224 CHW)
CLASS_DIM = 102  # number of output classes
BATCH_SIZE = 128


def main():
    """Build the network named on the command line and train it.

    After every pass the parameters are saved to ``params_pass_<N>.tar.gz``
    and the model is evaluated on ``flowers.valid()``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'model',
        help='The model for image classification',
        choices=['alexnet', 'vgg13', 'vgg16', 'vgg19', 'resnet', 'googlenet'])
    args = parser.parse_args()

    # NOTE(review): trainer_count=7 is hard-coded; presumably one trainer per
    # GPU, so lower it on hosts with fewer GPUs.
    paddle.init(use_gpu=True, trainer_count=7)

    image = paddle.layer.data(
        name="image", type=paddle.data_type.dense_vector(DATA_DIM))
    lbl = paddle.layer.data(
        name="label", type=paddle.data_type.integer_value(CLASS_DIM))

    # Additional cost layers optimized together with the main cost; only
    # GoogLeNet sets these.
    extra_layers = None
    learning_rate = 0.01
    if args.model == 'alexnet':
        out = alexnet.alexnet(image, class_dim=CLASS_DIM)
    elif args.model == 'vgg13':
        out = vgg.vgg13(image, class_dim=CLASS_DIM)
    elif args.model == 'vgg16':
        out = vgg.vgg16(image, class_dim=CLASS_DIM)
    elif args.model == 'vgg19':
        out = vgg.vgg19(image, class_dim=CLASS_DIM)
    elif args.model == 'resnet':
        out = resnet.resnet_imagenet(image, class_dim=CLASS_DIM)
        learning_rate = 0.1
    elif args.model == 'googlenet':
        # googlenet() returns the main output plus two auxiliary outputs;
        # each auxiliary head gets its own cross-entropy loss weighted 0.3
        # and its own classification-error evaluator.
        out, out1, out2 = googlenet.googlenet(image, class_dim=CLASS_DIM)
        loss1 = paddle.layer.cross_entropy_cost(
            input=out1, label=lbl, coeff=0.3)
        paddle.evaluator.classification_error(input=out1, label=lbl)
        loss2 = paddle.layer.cross_entropy_cost(
            input=out2, label=lbl, coeff=0.3)
        paddle.evaluator.classification_error(input=out2, label=lbl)
        extra_layers = [loss1, loss2]

    cost = paddle.layer.classification_cost(input=out, label=lbl)

    parameters = paddle.parameters.create(cost)

    # Learning rate is divided by, and L2 rate multiplied by, BATCH_SIZE;
    # presumably because PaddlePaddle v2 sums (not averages) per-sample
    # gradients over a batch -- confirm against the optimizer docs.
    optimizer = paddle.optimizer.Momentum(
        momentum=0.9,
        regularization=paddle.optimizer.L2Regularization(rate=0.0005 *
                                                         BATCH_SIZE),
        learning_rate=learning_rate / BATCH_SIZE,
        learning_rate_decay_a=0.1,
        learning_rate_decay_b=128000 * 35,
        learning_rate_schedule="discexp", )

    # NOTE(review): `reader` is imported but unused here; presumably it
    # provides readers for custom image lists to substitute for `flowers`.
    train_reader = paddle.batch(
        paddle.reader.shuffle(
            flowers.train(), buf_size=1000),
        batch_size=BATCH_SIZE)
    test_reader = paddle.batch(
        flowers.valid(), batch_size=BATCH_SIZE)

    trainer = paddle.trainer.SGD(
        cost=cost,
        parameters=parameters,
        update_equation=optimizer,
        extra_layers=extra_layers)

    log_period = 1  # print training metrics every `log_period` batches

    def event_handler(event):
        """Log batch metrics; checkpoint and evaluate at the end of each pass."""
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % log_period == 0:
                # Single-argument print works under both Python 2 and 3.
                print("\nPass %d, Batch %d, Cost %f, %s" % (
                    event.pass_id, event.batch_id, event.cost, event.metrics))
        if isinstance(event, paddle.event.EndPass):
            with gzip.open('params_pass_%d.tar.gz' % event.pass_id, 'w') as f:
                trainer.save_parameter_to_tar(f)

            result = trainer.test(reader=test_reader)
            print("\nTest with Pass %d, %s" % (event.pass_id, result.metrics))

    trainer.train(
        reader=train_reader, num_passes=200, event_handler=event_handler)


if __name__ == '__main__':
    main()
3.运行方式
python train.py googlenet
其中最后的 googlenet 指定要训练的网络模型;将其替换为其他支持的模型名称,如 alexnet、vgg13、vgg16、vgg19、resnet,即可用不同的网络结构进行训练。
用 TensorFlow 实现 GoogLeNet