Boosting のまとめ
Boosting のまとめ
ブースティングとは、Weak Learner(弱い学習機)のアンサンブルを作ることで Strong Learner(強い学習機)を生成するのを目的とした機械学習のアルゴリズムの一種である。ブースティングをまとめるために Weak Learner (WL) と Strong Learner (SL) とは何か、その概念を本記事で簡単に説明したいと考えている。Weak Learner とは、ある特徴量に対して分類しようとした結果がランダムにカテゴリを選ぶよりも精度が高いという若干の相関のある分類機である。基本的には WL の構造はシンプルである。 例えば、ある人を「男性」、「女性」という2つのカテゴリで分類するのが目的としよう。人間を表す特徴量は、例えば以下の通りの4次元ベクトルにしよう。
人の特徴量:(年齢、身長、マイナンバー、過去5年間でスカート・ドレスを履いた回数)
もちろん以上の特徴量(情報)で100%という確率で性別を予測することができない。しかも、全く性別に関係ない要素もある。だが、以下のように分類機を用意すると、どれくらいのな確率で予測できるのであろうか。
以下は、日本における男女平均身長のグラフである。中央値が平均値と異なる可能性が高いが、等しいとしよう。
source: NHK (https://www.nhk.or.jp/ohayou/digest/2017/04/0401.html)
男女平均身長を男女身長の中央値にした場合は、50% 以上の男性は 165cm より背が高く、50% 以上の女性は 165cm より背が低い。 すると、分類機によって正しく分類された人は、50% 以上である。これは、ランダムに性別を選ぶよりも正確だが、165cm より背が高い女性もいれば、 165cm より背が短い男性もいるから100%近づかないであろう。
さらに、もう一つ、以下の分類機を加えよう。
以上の分類機は身長の分類機と同様、100%という確率で性別を予測できないが、ランダムに選ぶよりも正確であろう。 以上の分類機を加え、両方の分類機の分類結果を考慮した上で分類した方が精度が上がるであろうという発想からブースティングが生まれた。 複数の WL を作り、それぞれの WL に重みをつけて、各 WL の結果を重みにかけた総和を出力する。出力するものが SL であり、SL の結果次第で分類(またはランキング)する。 こういったアルゴリズムはブースティングのアルゴリズムである。
ブースティングのアルゴリズムによって、WL の作り方、加え方が異なる。以下、例として LTR4L で実装してある RankBoost と LambdaMART という3つのブースティングアルゴリズムを簡単にまとめ、LTR4L でどのように使われているか、モデルファイルがどのような形をしているか説明しようと考える。
RankBoost
RankBoost ではペアワイズの分布を用意し、分布に基づいて各イタレーション(エポック)で WL を作り追加する。 追加する WL の予測結果によって分布を変え、それを繰り返す。 WL は性別を当てるという例で挙げた WL と同様、ある特徴とある閾値をもって予測する。 ペアワイズの分布が用いられているため学習するには文書2個が必要だが、予測するために必要な文書の個数は1つだけである。 (分布や WL についての詳細は論文を参照。)LTR4L を RankBoost (例えば以下のコマンドとデフォルトコンフィグ)で実行すると、
cd {LTR4L-*.jar Directory}
./train rankboost
以下のようなモデルファイルが作られる。{
"config" : {
"algorithm" : "rankboost",
"numIterations" : 100,
"batchSize" : 0,
"verbose" : true,
"nomodel" : false,
"params" : {
"numSteps" : 10,
"regularization" : {
"regularizer" : "L2",
"rate" : 0.01
}
},
"dataSet" : {
"training" : "data/MQ2008/Fold1/train.txt",
"validation" : "data/MQ2008/Fold1/vali.txt",
"test" : "data/MQ2008/Fold1/test.txt"
},
"model" : {
"format" : "json",
"file" : "model/rankboost-model.json"
},
"evaluation" : {
"evaluator" : "NDCG",
"params" : {
"k" : 10
}
},
"report" : {
"format" : "csv",
"file" : "report/rankboost-report.csv"
}
},
"features" : [ 22, 22, 22, 22, 22, 38, 22, 22, 38, 23, 23, 39, 37, 37, 21, 23, 39, 37, 37, 21, 38, 23, 39, 22, 22, 37, 37, 21, 23, 23, 39, 21, 21, 36, 20, 36, 20, 36, 20, 36, 39, 23, 21, 38, 22, 36, 36, 38, 38, 36, 36, 20, 39, 39, 23, 37, 21, 39, 39, 23, 37, 21, 38, 38, 22, 23, 23, 39, 22, 22, 38, 36, 36, 20, 37, 37, 21, 36, 36, 20, 36, 36, 20, 21, 21, 37, 39, 23, 38, 38, 38, 22, 10, 10, 10, 14, 4, 4, 0, 14 ],
"thresholds" : [ 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.5, 0.5, 0.5, 0.6, 0.6, 0.6, 0.5, 0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.30000000000000004, 0.30000000000000004, 0.30000000000000004, 0.30000000000000004, 0.30000000000000004, 0.30000000000000004, 0.30000000000000004, 0.30000000000000004, 0.6, 0.6, 0.4, 0.5, 0.5, 0.30000000000000004, 0.30000000000000004, 0.6, 0.6, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.7999999999999999, 0.7999999999999999, 0.7999999999999999, 0.7999999999999999, 0.7999999999999999, 0.7999999999999999, 0.7999999999999999, 0.7999999999999999, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.7999999999999999, 0.7999999999999999, 0.7999999999999999, 0.8999999999999999, 0.8999999999999999, 0.8999999999999999, 0.8999999999999999, 0.8999999999999999, 0.8999999999999999, 0.8999999999999999, 0.8999999999999999, 0.8999999999999999, 0.8999999999999999, 0.8999999999999999, 0.8999999999999999, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2 ],
"weights" : [ 0.41786064627910435, 0.6019294935562776, 0.8765316515447996, 1.2926812328638277, 1.9261765640212982, 1.7517158411371152, 1.1027753871492012, 1.616065850594727, 1.9348789644695028, 1.2068653815682175, 1.8135181103451798, 1.9050064900687507, 1.5097312329134034, 2.2476681697203134, 2.2047693832305426, 1.6794383221124585, 2.151269553805371, 1.4519278184742754, 2.1564488874569196, 1.8774324016329669, 1.7218843172353573, 1.5824009717596488, 2.0980593500406535, 1.429249907459846, 2.1215754530838438, 1.4191389172985864, 2.106017427813355, 1.7572220572522048, 1.5022554606081513, 2.2337378019636187, 2.28305909297725, 1.3661708054549946, 2.0244144900259444, 1.8902466751017266, 1.8447523347032473, 2.2221343944407566, 1.7019984430829036, 2.264463646804348, 1.6109941987459937, 2.0984156794780464, 1.609106451720564, 1.993747209603476, 1.7298564460664592, 1.617493734625768, 2.098846157685593, 1.5331976445154625, 2.2811880202448647, 1.3732325256301914, 2.03530365602232, 1.2500009328523054, 1.844825263188075, 1.9949523304194827, 1.3859424359081916, 2.0548946173431784, 1.8433723288992776, 1.5805769345391614, 1.7193867171357988, 1.1793151462900695, 1.7351133908954641, 1.7580843715574859, 1.6620924499223642, 2.1573828150208763, 1.237895273318067, 1.8260599649157931, 1.6954283598556683, 1.166399682676041, 1.7150304855788159, 2.0171101017105824, 1.2638683623704412, 1.8663094805123417, 2.1229238835780313, 1.2469004328768283, 1.8400200336490882, 1.8760549719323711, 1.2064089900599424, 1.7772059841695962, 2.070830580057891, 1.2905340376938987, 1.9075856451537092, 2.0062336511438814, 1.1174525828475614, 1.638816807314203, 1.854320860231896, 1.2511791149392368, 1.846651069453228, 1.8309695097149061, 1.5955632646972602, 1.667220714440436, 0.9387118213719363, 1.3591802738176193, 2.013632059695857, 2.069614900394956, 0.7562310725411754, 1.0842207179039267, 1.5934317478203175, 1.8652512117984072, 1.0743671949051556, 1.5782052357325353, 2.226855053899937, 0.9795899184679753 ]
}
features、thresholds、そして weights の数字(ダブル型)の配列は各 WL の特徴、閾値、そして重み(α)である。例えば、i番目の feature、threshold、 weight は i番目の WL の特徴、閾値、重みである。predict を上記のモデルで実行すると、WL のアンサンブルをモデルファイルに書いてある特徴・閾値・重みで作り、予測する。
LambdaMART
LambdaMART で使われる WL は決定木である。イメージとしては、RankBoost で使われる WL で作られた木である。 例えば、性別を当てる二つの WL を用い以下のように決定木が作れる。決定木を作るためには定義された損失関数を用いて、リーフ(ノード)追加時にそれを最小化する特徴と閾値を選ぶ。 リーフが出力スコアの計算時には、LambdaRank というニューラルネットワークのアルゴリズムで計算する Λ を用いる。 詳細は論文を参照。
同じく LambdaMART で実行すると、以下のようなモデルファイルが作られる。
{
"config" : {
"algorithm" : "LambdaMart",
"numIterations" : 100,
"batchSize" : 0,
"verbose" : true,
"nomodel" : false,
"params" : {
"numTrees" : 15,
"numLeaves" : 4,
"numSteps" : 10,
"learningRate" : 0.05,
"optimizer" : "adam",
"weightInit" : "xavier",
"regularization" : {
"regularizer" : "L2",
"rate" : 0.01
}
},
"dataSet" : {
"training" : "data/MQ2008/Fold1/train.txt",
"validation" : "data/MQ2008/Fold1/vali.txt",
"test" : "data/MQ2008/Fold1/test.txt"
},
"model" : {
"format" : "json",
"file" : "model/lambdamart-model.json"
},
"evaluation" : {
"evaluator" : "NDCG",
"params" : {
"k" : 10
}
},
"report" : {
"format" : "csv",
"file" : "report/lambdamart-report.csv"
}
},
"treeModels" : [ {
"config" : null,
"leafIds" : [ 0, 1, 3, 7, 8, 4, 2 ],
"featureIds" : [ 38, 38, 39, -1, -1, -1, -1 ],
"thresh" : [ 0.748092, 0.523628, 0.4, "-Infinity", "-Infinity", "-Infinity", "-Infinity" ],
"scores" : [ 0.0, 0.0, 0.0, 0.04757469250124157, 0.044148895268034685, -0.026720292627365028, 0.045137089557086 ]
}, {
"config" : null,
"leafIds" : [ 0, 1, 3, 7, 8, 4, 2 ],
"featureIds" : [ 22, 38, 39, -1, -1, -1, -1 ],
"thresh" : [ 0.695597, 0.5, 0.4, "-Infinity", "-Infinity", "-Infinity", "-Infinity" ],
"scores" : [ 0.0, 0.0, 0.0, 0.03169404750084058, -0.028349941303858234, 0.019450979352077697, 0.03906552378899862 ]
}, {
"config" : null,
"leafIds" : [ 0, 1, 3, 7, 8, 4, 2 ],
"featureIds" : [ 39, 22, 38, -1, -1, -1, -1 ],
"thresh" : [ 0.670009, 0.6, 0.4039212, "-Infinity", "-Infinity", "-Infinity", "-Infinity" ],
"scores" : [ 0.0, 0.0, 0.0, 0.08088942790279205, 0.13224122818059064, -0.16040080032741358, 0.10216624680738347 ]
}, {
"config" : null,
"leafIds" : [ 0, 1, 3, 7, 8, 4, 2 ],
"featureIds" : [ 23, 22, 38, -1, -1, -1, -1 ],
"thresh" : [ 0.680565, 0.6, 0.4, "-Infinity", "-Infinity", "-Infinity", "-Infinity" ],
"scores" : [ 0.0, 0.0, 0.0, 0.07837243733928041, 0.041224628957030246, 0.00417605005514371, 0.07721442404331602 ]
}, {
"config" : null,
"leafIds" : [ 0, 1, 3, 7, 8, 4, 2 ],
"featureIds" : [ 37, 22, 38, -1, -1, -1, -1 ],
"thresh" : [ 0.712139, 0.6, 0.364439, "-Infinity", "-Infinity", "-Infinity", "-Infinity" ],
"scores" : [ 0.0, 0.0, 0.0, 0.07452761986199616, 0.12291127583961003, -0.006119842773497125, 0.12162400021637708 ]
}, {
"config" : null,
"leafIds" : [ 0, 1, 3, 7, 8, 4, 2 ],
"featureIds" : [ 21, 22, 39, -1, -1, -1, -1 ],
"thresh" : [ 0.651999, 0.6, 0.4, "-Infinity", "-Infinity", "-Infinity", "-Infinity" ],
"scores" : [ 0.0, 0.0, 0.0, 0.0741931797831924, 0.11921222733167222, -0.013083465040347344, 0.11962579306542194 ]
}, {
"config" : null,
"leafIds" : [ 0, 1, 3, 7, 8, 4, 2 ],
"featureIds" : [ 20, 22, 39, -1, -1, -1, -1 ],
"thresh" : [ 0.726341, 0.6, 0.4, "-Infinity", "-Infinity", "-Infinity", "-Infinity" ],
"scores" : [ 0.0, 0.0, 0.0, 0.08630402592857363, 0.09020437037468336, 0.028866515737240826, 0.11375557544057655 ]
}, {
"config" : null,
"leafIds" : [ 0, 1, 3, 7, 8, 4, 2 ],
"featureIds" : [ 36, 22, 38, -1, -1, -1, -1 ],
"thresh" : [ 0.730147, 0.6, 0.364439, "-Infinity", "-Infinity", "-Infinity", "-Infinity" ],
"scores" : [ 0.0, 0.0, 0.0, 0.09136218557381391, 0.08321744787301914, -0.001967960819564388, 0.14357039789922074 ]
}, {
"config" : null,
"leafIds" : [ 0, 1, 3, 7, 8, 4, 2 ],
"featureIds" : [ 4, 39, 39, -1, -1, -1, -1 ],
"thresh" : [ 0.020699, 0.6, 0.4794520000000001, "-Infinity", "-Infinity", "-Infinity", "-Infinity" ],
"scores" : [ 0.0, 0.0, 0.0, 0.08013146840105015, 0.09545478632715355, 0.02934378042488794, 0.07077545026648407 ]
}, {
"config" : null,
"leafIds" : [ 0, 1, 3, 7, 8, 4, 2 ],
"featureIds" : [ 0, 39, 39, -1, -1, -1, -1 ],
"thresh" : [ 0.018605, 0.6, 0.4794520000000001, "-Infinity", "-Infinity", "-Infinity", "-Infinity" ],
"scores" : [ 0.0, 0.0, 0.0, 0.0770093310710862, 0.0709371115203059, 0.10898898009538896, 0.07843341281391535 ]
}, {
"config" : null,
"leafIds" : [ 0, 1, 2, 5, 11, 12, 6 ],
"featureIds" : [ 30, -1, 22, 39, -1, -1, -1 ],
"thresh" : [ 0.744277, "-Infinity", 0.7, 0.6, "-Infinity", "-Infinity", "-Infinity" ],
"scores" : [ 0.0, 0.10586981897596179, 0.0, 0.0, 0.07068146975207057, 0.0997888273784527, 0.10078212576457231 ]
}, {
"config" : null,
"leafIds" : [ 0, 1, 2, 5, 11, 12, 6 ],
"featureIds" : [ 27, -1, 22, 16, -1, -1, -1 ],
"thresh" : [ 0.525036, "-Infinity", 0.7, 0.6, "-Infinity", "-Infinity", "-Infinity" ],
"scores" : [ 0.0, 0.08871128488654877, 0.0, 0.0, 0.07523278574635008, 0.03710151109198109, 0.0898850290725211 ]
}, {
"config" : null,
"leafIds" : [ 0, 1, 2, 5, 11, 12, 6 ],
"featureIds" : [ 29, -1, 22, 39, -1, -1, -1 ],
"thresh" : [ 0.516782, "-Infinity", 0.7, 0.6, "-Infinity", "-Infinity", "-Infinity" ],
"scores" : [ 0.0, 0.07384295639836369, 0.0, 0.0, 0.08870516486892165, 0.023748922880171792, 0.08271444274742216 ]
}, {
"config" : null,
"leafIds" : [ 0, 1, 3, 7, 8, 4, 2 ],
"featureIds" : [ 15, 37, 39, -1, -1, -1, -1 ],
"thresh" : [ 0.004402, 0.7999999999999999, 0.5, "-Infinity", "-Infinity", "-Infinity", "-Infinity" ],
"scores" : [ 0.0, 0.0, 0.0, 0.07532283071858566, 0.08651952072801597, 0.04125361938656299, 0.0697530492150944 ]
}, {
"config" : null,
"leafIds" : [ 0, 1, 2, 5, 11, 12, 6 ],
"featureIds" : [ 31, -1, 22, 39, -1, -1, -1 ],
"thresh" : [ 0.44423, "-Infinity", 0.7, 0.6, "-Infinity", "-Infinity", "-Infinity" ],
"scores" : [ 0.0, 0.10468543490261305, 0.0, 0.0, 0.08091616052015302, -0.015292945578252873, 0.11528456381498506 ]
} ]
treeModels というところに決定木のアンサンブルの各木の各ノードの特徴、閾値、スコアなどが書いてある。
leafId は ノードの位置を指す。以下の通り leafId が付けられる。一番目の決定木を説明する。
{
"config" : null,
"leafIds" : [ 0, 1, 3, 7, 8, 4, 2 ],
"featureIds" : [ 38, 38, 39, -1, -1, -1, -1 ],
"thresh" : [ 0.748092, 0.523628, 0.4, "-Infinity", "-Infinity", "-Infinity", "-Infinity" ],
"scores" : [ 0.0, 0.0, 0.0, 0.04757469250124157, 0.044148895268034685, -0.026720292627365028, 0.045137089557086 ]
}
上記の決定木の構造は以下の通りである。ノードがなければ、その位置に該当する id はモデルファイルに書き込まれない。 featureIds, thresh, scores は、leafId の添字と一緒になっている。例えば、leafId 3 に対応する featureId, threshold, score は featureIds, thresh, scores の三番目の項目(index = 2) に書いてある。
thresh に "-Infinity" や featureId に -1 が書いてある場合は、そのノードがリーフ(出力ノード・下にノードがない)になっており、score が 0 の場合は ノードの下にノードがあるという意味をしている。
以上