业界 | 扒一扒Facebook人工智能谈判模型 — Facebook开源的”端到端”强化学习模型

端到端学习是那么吸引人, 因为它与理想的”自治”学习是那么近. — David 9

我们离完全”自治”的AI系统还很远很远, 没有自我采集样本的能力, 没有自己构建模型的能力, 也缺少”端到端” 学什么就像什么的灵活性. 而最近Facebook 人工智能研究所(FAIR)的研究人员公开了一个具有谈判新能力的对话智能体(dialog agents),并开源了其代码, 在”端到端” 这一方向上更进了一步:

这篇文章的突破仅限于智能对话, 更像是一篇专利, 教大家如何用一堆神经网络训练一个智能对话来获得谈判最终利益. 另外值得注意的是该pytorch项目虽然开源, 但是是经过 creativecommonsNonCommercial 4.0 非商业化协议保护的, 即, 你可以研究和使用代码, 但是你不能直接用它做商业用途.

言归正传, David 9 想说的是, 这个近乎科幻的对话机器人, 其实并没有那么神奇.

首先看看Facebook一伙人怎么收集对话(dialog)数据的 :

Facebook这伙人收集的数据是从亚马逊 Mechanical Turk 交易网站上 买来的, $0.15一个对话, 总共买了5808个对话.

瞧, 我们还在一个”茹毛饮血” 用钱买数据的时代.

想要赚$0.15一个对话很简单, 如上图, 你只要根据每个物品的计分(书本得0分, 帽子得7分, 篮球得1分) 和对方谈判需要的物件, 尽量获得最大总得分, 谈判成交, 即可 .

在如此简单的实验环境和语料集合下(只有书本, 帽子, 篮球等简单物件), 作者证明了我们可以训练神经网络使得它具有一定的谈判能力.

总体训练方式也比较简单:

如上图, 把完整的对话分割成两个实验对象(Agent 1 和 Agent 2) 各自角度的独立对话. 这样, 输入就是每个实验对象的谈判目标goal ( 即, 各个物件价值, 如:  “书本得3分, 帽子得2分, 篮球得1分”). (紫色部分) 训练数据还包含整体对话(黄色部分) 以及最后的输出, 即谈判结束后获得的总分. (红色部分)

本质上, 文章把这个训练问题看做对整体对话(黄色部分)的机器翻译问题, 即, 对整体对话的自然语言进行翻译, 使得机器可以明白谈判最后是什么总得分.

传统的机器翻译可能是这样的编码-解码器:

唯一的不同是, 这里的输出不再是英语翻译的对应法语, 或其他语言, 而是机器理解的总得分情况. (如, 我得到了两个帽子, 总计了6分)

文章深度借鉴了上述思想, 用了一堆神经网络完成对自然语言的理解和推断. 大致包括四个神经网络:

模型1. 负责编码实验对象的目标goal 的神经网络GRUg

模型2.负责编码整个对话自然语言生成的神经网络GRUw

模型3. 负责编码对话->输出的神经网络GRUo1

模型4.负责编码对话<-输出的神经网络GRUo2

这些网络组成的监督学习网络, 学习到了自然语言到谈判目标的依赖关系, 即, 什么样的谈判目标+什么样的谈判对话, 会有什么样的谈判结果.

这里的Input Encoder就是模型1, Output Decoder就是模型3, 4, 最底下横排神经元的生成网络, 就是模型2. (图中最底下的垂直箭头是生成模型的抽样样本, 倾斜的箭头是监督学习的监督样本.)

这样的学习只能教会计算机如何完成一个谈判对话. 而不能让谈判利益最大化. 文章进一步使用增强学习和rollouts算法, 进行模型的第二阶段参数调优:

这次, 一些垂直箭头不见了, 一些抽样被实际的”端到端”人为干预, 实现了增强学习, 把着重点落在如何在特定时刻, 使用更好的谈判语句.

事实上, 作者为了优化这一谈判能力, 煞费苦心:

在每个轮到机器的谈判阶段, 用rollouts算法计算多个候选答案的可能性(Candidate responses) , 模拟进一步的谈判知道结束, 比较最终的总分, 选取最佳谈判方案.

所以总结整个训练过程分为两个阶段:

l. 训练所有神经网络使得计算机学会基本的谈判行为和流程

2. 优化模型, 用强化学习获得的经验, 优化每一步谈判的最佳选择 (想必会吸取以前谈判失败的教训)

最后, 我们来看看几个有意思的模型训练后谈判实例:

上图, 如果强化学习用的过度, 机器会固执地坚持自己的选择, 完全不考虑别人的感受 !

上图, 在可能共赢的情况下, 机器也可能做出让步.

有时, 机器开始也会表现出对一些没有价值物件的兴趣, 为之后的谈判增加筹码.

 

源码运行结果:

(py30) yanchao727@yanchao727-VirtualBox:~/software/end-to-end-negotiator/src$ python train.py   --data data/negotiate    --bsz 16   --clip 0.5   --decay_every 1   --decay_rate 5.0   --dropout 0.5   --init_range 0.1   --lr 1   --max_epoch 30   --min_lr 0.01   --momentum 0.1   --nembed_ctx 64   --nembed_word 256   --nesterov   --nhid_attn 256   --nhid_ctx 64   --nhid_lang 128   --nhid_sel 256   --nhid_strat 128   --sel_weight 0.5 dataset data/negotiate/train.txt, total 687919, unks 8718, ratio 1.27%
dataset data/negotiate/val.txt, total 74653, unks 914, ratio 1.22%
dataset data/negotiate/test.txt, total 70262, unks 847, ratio 1.21%
| epoch 001 | trainloss 3.583 | trainppl 35.987 | s/epoch 623.66 | lr 1.00000000
| epoch 001 | validloss 2.632 | validppl 13.903
| epoch 001 | validselectloss 1.241 | validselectppl 3.460
| epoch 002 | trainloss 2.866 | trainppl 17.563 | s/epoch 696.01 | lr 1.00000000
| epoch 002 | validloss 2.325 | validppl 10.225
| epoch 002 | validselectloss 1.173 | validselectppl 3.232
| epoch 003 | trainloss 2.623 | trainppl 13.778 | s/epoch 670.58 | lr 1.00000000
| epoch 003 | validloss 2.146 | validppl 8.547
| epoch 003 | validselectloss 0.950 | validselectppl 2.585
| epoch 004 | trainloss 2.396 | trainppl 10.983 | s/epoch 648.90 | lr 1.00000000
| epoch 004 | validloss 2.018 | validppl 7.520
| epoch 004 | validselectloss 0.726 | validselectppl 2.066
| epoch 005 | trainloss 2.210 | trainppl 9.115 | s/epoch 701.42 | lr 1.00000000
| epoch 005 | validloss 1.988 | validppl 7.302
| epoch 005 | validselectloss 0.653 | validselectppl 1.921
| epoch 006 | trainloss 2.122 | trainppl 8.345 | s/epoch 639.28 | lr 1.00000000
| epoch 006 | validloss 1.939 | validppl 6.952
| epoch 006 | validselectloss 0.539 | validselectppl 1.715
| epoch 007 | trainloss 2.055 | trainppl 7.805 | s/epoch 671.96 | lr 1.00000000
| epoch 007 | validloss 1.891 | validppl 6.629
| epoch 007 | validselectloss 0.491 | validselectppl 1.634
| epoch 008 | trainloss 2.005 | trainppl 7.425 | s/epoch 229997.76 | lr 1.00000000
| epoch 008 | validloss 1.893 | validppl 6.637
| epoch 008 | validselectloss 0.444 | validselectppl 1.559
| epoch 009 | trainloss 1.963 | trainppl 7.117 | s/epoch 662.04 | lr 1.00000000
| epoch 009 | validloss 1.887 | validppl 6.602
| epoch 009 | validselectloss 0.469 | validselectppl 1.599
| epoch 010 | trainloss 1.923 | trainppl 6.841 | s/epoch 645.77 | lr 1.00000000
| epoch 010 | validloss 1.851 | validppl 6.366
| epoch 010 | validselectloss 0.358 | validselectppl 1.430
| epoch 011 | trainloss 1.887 | trainppl 6.597 | s/epoch 672.44 | lr 1.00000000
| epoch 011 | validloss 1.843 | validppl 6.313
| epoch 011 | validselectloss 0.334 | validselectppl 1.397
| epoch 012 | trainloss 1.853 | trainppl 6.382 | s/epoch 713.18 | lr 1.00000000
| epoch 012 | validloss 1.839 | validppl 6.288
| epoch 012 | validselectloss 0.286 | validselectppl 1.331
| epoch 013 | trainloss 1.829 | trainppl 6.229 | s/epoch 685.65 | lr 1.00000000
| epoch 013 | validloss 1.833 | validppl 6.254
| epoch 013 | validselectloss 0.282 | validselectppl 1.325
| epoch 014 | trainloss 1.803 | trainppl 6.070 | s/epoch 700.42 | lr 1.00000000
| epoch 014 | validloss 1.885 | validppl 6.587
| epoch 014 | validselectloss 0.219 | validselectppl 1.245
| epoch 015 | trainloss 1.778 | trainppl 5.921 | s/epoch 694.53 | lr 1.00000000
| epoch 015 | validloss 1.840 | validppl 6.297
| epoch 015 | validselectloss 0.255 | validselectppl 1.291
| epoch 016 | trainloss 1.757 | trainppl 5.798 | s/epoch 666.82 | lr 1.00000000
| epoch 016 | validloss 1.831 | validppl 6.239
| epoch 016 | validselectloss 0.220 | validselectppl 1.246
| epoch 017 | trainloss 1.741 | trainppl 5.703 | s/epoch 636.02 | lr 1.00000000
| epoch 017 | validloss 1.796 | validppl 6.027
| epoch 017 | validselectloss 0.216 | validselectppl 1.242
| epoch 018 | trainloss 1.720 | trainppl 5.586 | s/epoch 651.39 | lr 1.00000000
| epoch 018 | validloss 1.853 | validppl 6.380
| epoch 018 | validselectloss 0.179 | validselectppl 1.196
| epoch 019 | trainloss 1.707 | trainppl 5.510 | s/epoch 708.23 | lr 1.00000000
| epoch 019 | validloss 1.791 | validppl 5.995
| epoch 019 | validselectloss 0.214 | validselectppl 1.239
| epoch 020 | trainloss 1.693 | trainppl 5.438 | s/epoch 716.78 | lr 1.00000000
| epoch 020 | validloss 1.814 | validppl 6.134
| epoch 020 | validselectloss 0.168 | validselectppl 1.182
| epoch 021 | trainloss 1.676 | trainppl 5.347 | s/epoch 674.25 | lr 1.00000000
| epoch 021 | validloss 1.808 | validppl 6.101
| epoch 021 | validselectloss 0.173 | validselectppl 1.189
| epoch 022 | trainloss 1.665 | trainppl 5.287 | s/epoch 671.06 | lr 1.00000000
| epoch 022 | validloss 1.792 | validppl 6.001
| epoch 022 | validselectloss 0.167 | validselectppl 1.182
| epoch 023 | trainloss 1.651 | trainppl 5.211 | s/epoch 678.09 | lr 1.00000000
| epoch 023 | validloss 1.782 | validppl 5.941
| epoch 023 | validselectloss 0.131 | validselectppl 1.140
| epoch 024 | trainloss 1.641 | trainppl 5.159 | s/epoch 686.55 | lr 1.00000000
| epoch 024 | validloss 1.783 | validppl 5.950
| epoch 024 | validselectloss 0.160 | validselectppl 1.173
| epoch 025 | trainloss 1.625 | trainppl 5.078 | s/epoch 720.00 | lr 1.00000000
| epoch 025 | validloss 1.808 | validppl 6.097
| epoch 025 | validselectloss 0.136 | validselectppl 1.145
| epoch 026 | trainloss 1.619 | trainppl 5.046 | s/epoch 727.31 | lr 1.00000000
| epoch 026 | validloss 1.772 | validppl 5.880
| epoch 026 | validselectloss 0.139 | validselectppl 1.149
| epoch 027 | trainloss 1.609 | trainppl 4.997 | s/epoch 725.96 | lr 1.00000000
| epoch 027 | validloss 1.768 | validppl 5.857
| epoch 027 | validselectloss 0.158 | validselectppl 1.172
| epoch 028 | trainloss 1.600 | trainppl 4.954 | s/epoch 736.97 | lr 1.00000000
| epoch 028 | validloss 1.772 | validppl 5.883
| epoch 028 | validselectloss 0.130 | validselectppl 1.139
| epoch 029 | trainloss 1.590 | trainppl 4.903 | s/epoch 674.34 | lr 1.00000000
| epoch 029 | validloss 1.773 | validppl 5.891
| epoch 029 | validselectloss 0.105 | validselectppl 1.110
| epoch 030 | trainloss 1.584 | trainppl 4.874 | s/epoch 664.10 | lr 1.00000000
| epoch 030 | validloss 1.797 | validppl 6.031
| epoch 030 | validselectloss 0.099 | validselectppl 1.104
| start annealing | best validselectloss 0.099 | best validselectppl 1.104
| epoch 031 | trainloss 1.506 | trainppl 4.509 | s/epoch 640.44 | lr 0.20000000
| epoch 031 | validloss 1.728 | validppl 5.627
| epoch 031 | validselectloss 0.090 | validselectppl 1.095
| epoch 032 | trainloss 1.483 | trainppl 4.405 | s/epoch 681.86 | lr 0.04000000
| epoch 032 | validloss 1.725 | validppl 5.611
| epoch 032 | validselectloss 0.087 | validselectppl 1.091
final selectppl 1.091

 

 

参考文献:

  1. https://arxiv.org/pdf/1706.05125.pdf
  2. https://arxiv.org/pdf/1409.1259.pdfFAIR
  3. https://zhuanlan.zhihu.com/p/27410000

本文采用署名 – 非商业性使用 – 禁止演绎 3.0 中国大陆许可协议进行许可。著作权属于“David 9的博客”原创,如需转载,请联系微信: david9ml,或邮箱:yanchao727@gmail.com

或直接扫二维码:

发布者

David 9

邮箱:yanchao727@gmail.com 微信: david9ml

发表回复

您的电子邮箱地址不会被公开。 必填项已用 * 标注