端到端学习是那么吸引人, 因为它与理想的”自治”学习是那么近. — David 9
我们离完全”自治”的AI系统还很远很远, 没有自我采集样本的能力, 没有自己构建模型的能力, 也缺少”端到端” 学什么就像什么的灵活性. 而最近Facebook 人工智能研究所(FAIR)的研究人员公开了一个具有谈判新能力的对话智能体(dialog agents),并开源了其代码, 在”端到端” 这一方向上更进了一步:
-
论文地址:Deal or No Deal? End-to-End Learning for Negotiation Dialogues
-
开源地址:https://github.com/facebookresearch/end-to-end-negotiator
这篇文章的突破仅限于智能对话, 更像是一篇专利, 教大家如何用一堆神经网络训练一个智能对话来获得谈判最终利益. 另外值得注意的是该pytorch项目虽然开源, 但是是经过 creativecommons的NonCommercial 4.0 非商业化协议保护的, 即, 你可以研究和使用代码, 但是你不能直接用它做商业用途.
言归正传, David 9 想说的是, 这个近乎科幻的对话机器人, 其实并没有那么神奇.
首先看看Facebook一伙人怎么收集对话(dialog)数据的 :
Facebook这伙人收集的数据是从亚马逊 Mechanical Turk 交易网站上 买来的, $0.15一个对话, 总共买了5808个对话.
瞧, 我们还在一个”茹毛饮血” 用钱买数据的时代.
想要赚$0.15一个对话很简单, 如上图, 你只要根据每个物品的计分(书本得0分, 帽子得7分, 篮球得1分) 和对方谈判需要的物件, 尽量获得最大总得分, 谈判成交, 即可 .
在如此简单的实验环境和语料集合下(只有书本, 帽子, 篮球等简单物件), 作者证明了我们可以训练神经网络使得它具有一定的谈判能力.
总体训练方式也比较简单:
如上图, 把完整的对话分割成两个实验对象(Agent 1 和 Agent 2) 各自角度的独立对话. 这样, 输入就是每个实验对象的谈判目标goal ( 即, 各个物件价值, 如: “书本得3分, 帽子得2分, 篮球得1分”). (紫色部分) 训练数据还包含整体对话(黄色部分) 以及最后的输出, 即谈判结束后获得的总分. (红色部分)
本质上, 文章把这个训练问题看做对整体对话(黄色部分)的机器翻译问题, 即, 对整体对话的自然语言进行翻译, 使得机器可以明白谈判最后是什么总得分.
传统的机器翻译可能是这样的编码-解码器:
唯一的不同是, 这里的输出不再是英语翻译的对应法语, 或其他语言, 而是机器理解的总得分情况. (如, 我得到了两个帽子, 总计了6分)
文章深度借鉴了上述思想, 用了一堆神经网络完成对自然语言的理解和推断. 大致包括四个神经网络:
模型1. 负责编码实验对象的目标goal 的神经网络GRUg
模型2.负责编码整个对话自然语言生成的神经网络GRUw
模型3. 负责编码对话->输出的神经网络GRUo1
模型4.负责编码对话<-输出的神经网络GRUo2
这些网络组成的监督学习网络, 学习到了自然语言到谈判目标的依赖关系, 即, 什么样的谈判目标+什么样的谈判对话, 会有什么样的谈判结果.
这里的Input Encoder就是模型1, Output Decoder就是模型3, 4, 最底下横排神经元的生成网络, 就是模型2. (图中最底下的垂直箭头是生成模型的抽样样本, 倾斜的箭头是监督学习的监督样本.)
这样的学习只能教会计算机如何完成一个谈判对话. 而不能让谈判利益最大化. 文章进一步使用增强学习和rollouts算法, 进行模型的第二阶段参数调优:
这次, 一些垂直箭头不见了, 一些抽样被实际的”端到端”人为干预, 实现了增强学习, 把着重点落在如何在特定时刻, 使用更好的谈判语句.
在每个轮到机器的谈判阶段, 用rollouts算法计算多个候选答案的可能性(Candidate responses) , 模拟进一步的谈判知道结束, 比较最终的总分, 选取最佳谈判方案.
所以总结整个训练过程分为两个阶段:
l. 训练所有神经网络使得计算机学会基本的谈判行为和流程
2. 优化模型, 用强化学习获得的经验, 优化每一步谈判的最佳选择 (想必会吸取以前谈判失败的教训)
最后, 我们来看看几个有意思的模型训练后谈判实例:
上图, 如果强化学习用的过度, 机器会固执地坚持自己的选择, 完全不考虑别人的感受 !
有时, 机器开始也会表现出对一些没有价值物件的兴趣, 为之后的谈判增加筹码.
源码运行结果:
(py30) yanchao727@yanchao727-VirtualBox:~/software/end-to-end-negotiator/src$ python train.py --data data/negotiate --bsz 16 --clip 0.5 --decay_every 1 --decay_rate 5.0 --dropout 0.5 --init_range 0.1 --lr 1 --max_epoch 30 --min_lr 0.01 --momentum 0.1 --nembed_ctx 64 --nembed_word 256 --nesterov --nhid_attn 256 --nhid_ctx 64 --nhid_lang 128 --nhid_sel 256 --nhid_strat 128 --sel_weight 0.5 dataset data/negotiate/train.txt, total 687919, unks 8718, ratio 1.27% dataset data/negotiate/val.txt, total 74653, unks 914, ratio 1.22% dataset data/negotiate/test.txt, total 70262, unks 847, ratio 1.21% | epoch 001 | trainloss 3.583 | trainppl 35.987 | s/epoch 623.66 | lr 1.00000000 | epoch 001 | validloss 2.632 | validppl 13.903 | epoch 001 | validselectloss 1.241 | validselectppl 3.460 | epoch 002 | trainloss 2.866 | trainppl 17.563 | s/epoch 696.01 | lr 1.00000000 | epoch 002 | validloss 2.325 | validppl 10.225 | epoch 002 | validselectloss 1.173 | validselectppl 3.232 | epoch 003 | trainloss 2.623 | trainppl 13.778 | s/epoch 670.58 | lr 1.00000000 | epoch 003 | validloss 2.146 | validppl 8.547 | epoch 003 | validselectloss 0.950 | validselectppl 2.585 | epoch 004 | trainloss 2.396 | trainppl 10.983 | s/epoch 648.90 | lr 1.00000000 | epoch 004 | validloss 2.018 | validppl 7.520 | epoch 004 | validselectloss 0.726 | validselectppl 2.066 | epoch 005 | trainloss 2.210 | trainppl 9.115 | s/epoch 701.42 | lr 1.00000000 | epoch 005 | validloss 1.988 | validppl 7.302 | epoch 005 | validselectloss 0.653 | validselectppl 1.921 | epoch 006 | trainloss 2.122 | trainppl 8.345 | s/epoch 639.28 | lr 1.00000000 | epoch 006 | validloss 1.939 | validppl 6.952 | epoch 006 | validselectloss 0.539 | validselectppl 1.715 | epoch 007 | trainloss 2.055 | trainppl 7.805 | s/epoch 671.96 | lr 1.00000000 | epoch 007 | validloss 1.891 | validppl 6.629 | epoch 007 | validselectloss 0.491 | validselectppl 1.634 | epoch 008 | trainloss 2.005 | trainppl 7.425 | s/epoch 229997.76 | lr 1.00000000 | epoch 008 | validloss 1.893 | validppl 6.637 | epoch 008 | validselectloss 0.444 | validselectppl 1.559 | epoch 009 | trainloss 1.963 | trainppl 7.117 | s/epoch 662.04 | lr 1.00000000 | epoch 009 | validloss 1.887 | validppl 6.602 | epoch 009 | validselectloss 0.469 | validselectppl 1.599 | epoch 010 | trainloss 1.923 | trainppl 6.841 | s/epoch 645.77 | lr 1.00000000 | epoch 010 | validloss 1.851 | validppl 6.366 | epoch 010 | validselectloss 0.358 | validselectppl 1.430 | epoch 011 | trainloss 1.887 | trainppl 6.597 | s/epoch 672.44 | lr 1.00000000 | epoch 011 | validloss 1.843 | validppl 6.313 | epoch 011 | validselectloss 0.334 | validselectppl 1.397 | epoch 012 | trainloss 1.853 | trainppl 6.382 | s/epoch 713.18 | lr 1.00000000 | epoch 012 | validloss 1.839 | validppl 6.288 | epoch 012 | validselectloss 0.286 | validselectppl 1.331 | epoch 013 | trainloss 1.829 | trainppl 6.229 | s/epoch 685.65 | lr 1.00000000 | epoch 013 | validloss 1.833 | validppl 6.254 | epoch 013 | validselectloss 0.282 | validselectppl 1.325 | epoch 014 | trainloss 1.803 | trainppl 6.070 | s/epoch 700.42 | lr 1.00000000 | epoch 014 | validloss 1.885 | validppl 6.587 | epoch 014 | validselectloss 0.219 | validselectppl 1.245 | epoch 015 | trainloss 1.778 | trainppl 5.921 | s/epoch 694.53 | lr 1.00000000 | epoch 015 | validloss 1.840 | validppl 6.297 | epoch 015 | validselectloss 0.255 | validselectppl 1.291 | epoch 016 | trainloss 1.757 | trainppl 5.798 | s/epoch 666.82 | lr 1.00000000 | epoch 016 | validloss 1.831 | validppl 6.239 | epoch 016 | validselectloss 0.220 | validselectppl 1.246 | epoch 017 | trainloss 1.741 | trainppl 5.703 | s/epoch 636.02 | lr 1.00000000 | epoch 017 | validloss 1.796 | validppl 6.027 | epoch 017 | validselectloss 0.216 | validselectppl 1.242 | epoch 018 | trainloss 1.720 | trainppl 5.586 | s/epoch 651.39 | lr 1.00000000 | epoch 018 | validloss 1.853 | validppl 6.380 | epoch 018 | validselectloss 0.179 | validselectppl 1.196 | epoch 019 | trainloss 1.707 | trainppl 5.510 | s/epoch 708.23 | lr 1.00000000 | epoch 019 | validloss 1.791 | validppl 5.995 | epoch 019 | validselectloss 0.214 | validselectppl 1.239 | epoch 020 | trainloss 1.693 | trainppl 5.438 | s/epoch 716.78 | lr 1.00000000 | epoch 020 | validloss 1.814 | validppl 6.134 | epoch 020 | validselectloss 0.168 | validselectppl 1.182 | epoch 021 | trainloss 1.676 | trainppl 5.347 | s/epoch 674.25 | lr 1.00000000 | epoch 021 | validloss 1.808 | validppl 6.101 | epoch 021 | validselectloss 0.173 | validselectppl 1.189 | epoch 022 | trainloss 1.665 | trainppl 5.287 | s/epoch 671.06 | lr 1.00000000 | epoch 022 | validloss 1.792 | validppl 6.001 | epoch 022 | validselectloss 0.167 | validselectppl 1.182 | epoch 023 | trainloss 1.651 | trainppl 5.211 | s/epoch 678.09 | lr 1.00000000 | epoch 023 | validloss 1.782 | validppl 5.941 | epoch 023 | validselectloss 0.131 | validselectppl 1.140 | epoch 024 | trainloss 1.641 | trainppl 5.159 | s/epoch 686.55 | lr 1.00000000 | epoch 024 | validloss 1.783 | validppl 5.950 | epoch 024 | validselectloss 0.160 | validselectppl 1.173 | epoch 025 | trainloss 1.625 | trainppl 5.078 | s/epoch 720.00 | lr 1.00000000 | epoch 025 | validloss 1.808 | validppl 6.097 | epoch 025 | validselectloss 0.136 | validselectppl 1.145 | epoch 026 | trainloss 1.619 | trainppl 5.046 | s/epoch 727.31 | lr 1.00000000 | epoch 026 | validloss 1.772 | validppl 5.880 | epoch 026 | validselectloss 0.139 | validselectppl 1.149 | epoch 027 | trainloss 1.609 | trainppl 4.997 | s/epoch 725.96 | lr 1.00000000 | epoch 027 | validloss 1.768 | validppl 5.857 | epoch 027 | validselectloss 0.158 | validselectppl 1.172 | epoch 028 | trainloss 1.600 | trainppl 4.954 | s/epoch 736.97 | lr 1.00000000 | epoch 028 | validloss 1.772 | validppl 5.883 | epoch 028 | validselectloss 0.130 | validselectppl 1.139 | epoch 029 | trainloss 1.590 | trainppl 4.903 | s/epoch 674.34 | lr 1.00000000 | epoch 029 | validloss 1.773 | validppl 5.891 | epoch 029 | validselectloss 0.105 | validselectppl 1.110 | epoch 030 | trainloss 1.584 | trainppl 4.874 | s/epoch 664.10 | lr 1.00000000 | epoch 030 | validloss 1.797 | validppl 6.031 | epoch 030 | validselectloss 0.099 | validselectppl 1.104 | start annealing | best validselectloss 0.099 | best validselectppl 1.104 | epoch 031 | trainloss 1.506 | trainppl 4.509 | s/epoch 640.44 | lr 0.20000000 | epoch 031 | validloss 1.728 | validppl 5.627 | epoch 031 | validselectloss 0.090 | validselectppl 1.095 | epoch 032 | trainloss 1.483 | trainppl 4.405 | s/epoch 681.86 | lr 0.04000000 | epoch 032 | validloss 1.725 | validppl 5.611 | epoch 032 | validselectloss 0.087 | validselectppl 1.091 final selectppl 1.091
参考文献:
- https://arxiv.org/pdf/1706.05125.pdf
- https://arxiv.org/pdf/1409.1259.pdfFAIR
- https://zhuanlan.zhihu.com/p/27410000
本文采用署名 – 非商业性使用 – 禁止演绎 3.0 中国大陆许可协议进行许可。著作权属于“David 9的博客”原创,如需转载,请联系微信: david9ml,或邮箱:yanchao727@gmail.com
或直接扫二维码:
David 9
Latest posts by David 9 (see all)
- 修订特征已经变得切实可行, “特征矫正工程”是否会成为潮流? - 27 3 月, 2024
- 量子计算系列#2 : 量子机器学习与量子深度学习补充资料,QML,QeML,QaML - 29 2 月, 2024
- “现象意识”#2:用白盒的视角研究意识和大脑,会是什么景象?微意识,主体感,超心智,意识中层理论 - 16 2 月, 2024