「機械学習の論文をスクラッチから実装しよう!」LambdaMART の数式をコードに落とす
「機械学習の論文をスクラッチから実装しよう!」というシリーズは、機械学習で現れる数式をコードに落とすことを目的としています。論文で数式がでてくると、「これは 具体的にはどういう式なのでしょう」と思うことは多いでしょう。その数式が実際のコードで実装された形であれば、漠然としていた数式がわかりやすくなるのではないかと 考え、本シリーズを書くことにしました。本記事では、数式の実装に重点を置いてコードを書きます。本記事に含まれたコードはアルゴリズムの評価、判別などの完全実装ではなく、 むしろ機械学習のアルゴリズムを実装するにあたって参考になればなと考えています。提供しているコードは Learning-to-Rank for Lucene というオープンソースで筆者が書いたコードに基づいています。
ランキング学習
LambdaMART はランキング学習(Learning to Rank)のアルゴリズムです。ランキング学習は、情報検索において多用されており、yes か no、クラス・関連度予測ではなく、複数の文書のランキング(表示順)に基づいた学習です。例えば、以下二つの文書があるとしましょう(星の数が多いほど関連度が高い)。
文書A, 星の数:5
文書B, 星の数:4
文書C, 星の数:3
学習済みのモデルでそれぞれの文書に星の数を予測しようとしたら、以下の結果が出たとします。
文書A, 予測値:4
文書B, 予測値:2
文書C, 予測値:1
上記の予測値は、期待の星の数とは異なりますが、星の数の高い順で並ぶと、表示順が変わりません。ユーザにとっては結果の表示順が理想と全く同じです。従って、文書のランキングに基づいて学習し、学習されたモデルを文書のランキングで評価すべきではないかというイデアが生まれ、ランキング学習のアルゴリズムが開発されました。
LambdaMART
さて、本題に入りましょう。LambdaMARTは、どういう機械学習のアルゴリズムかというと、勾配ブースティング回帰木(Gradient Boosted Regression Tree)のアルゴリズムです。具体的には LambdaMART はニューラルネットワークのアルゴリズム「LambdaRank」の勾配を用いて、MART(Multiple Added Regression Trees)というブースティングアルゴリズムで作る回帰木のアンサンブルの各木のスコアを計算します。LambdaMART の回帰木のイメージは以下の通りです。
本記事では、LambdaMART の擬似コード、数式を参考にして実際にそれらを Java コードに落として行きたいと考えます。まずは回帰木の作成を実装してから回帰木のリーフ(出力ノード)のスコア計算を実装します。評価の実装は読者に任せます。
回帰木の作成
前章で述べたように、LambdaMART は MART の回帰木作成ルールに従い回帰木のアンサンブルを作ります。構築する回帰木は一般の決定木のように、出力ノード以外各ノードに閾値と子ノードがあり、あるインスタンスのある特徴量が閾値超えるとそのインスタンスが右の子ノードへ、超えない場合は左の子ノードへ移ります。出力ノードにたどり着くまで繰り返します。アンサンブルの各木の出力の総和がモデルの出力となります。木の数、また各木のリーフ数もユーザに指定されたとします。一つのいたレーションで木を構築するときに各ノードの特徴量と閾値を決定するところを以下解説します。最初のレーション(木がまだないとき)の出力を0とします。
まずは、木とノードクラスを作ります。
public class Node {
private final Node parent;
private Node left;
private Node right;
private double threshold;
private double score;
private int featureId;
private final List<Document> scoredDocs; // this にたどり着いた、または this を通ったインスタンスのリスト
private final int nodeId;
public static makeRoot(int featureId, double threshold, List<Document> data) {
return new Node(featureId, threshold, data)
}
private Node(int featureId, double threshold, List<Document> data) { //ルートのコンストラクタ
this.source = null;
this.featureId = featureId;
this.threshold = threshold;
this.scoredDocs = data; //全データがルートを通る
this.score = 0.0; //ルートはスコアを出力しない
List<Document> leftDocs = new ArrayList<>();
List<Document> rightDocs = new ArrayList<>();
for (Document doc : data) {
if (doc.getFeature(featureId) < threshold) leftDocs.add(doc);
else rightDocs.add(doc);
}
left = new Node(this, leftDocs, 1);
right = new Node(this, rightDocs, 2);
}
private Node(Node source, List<Document> scoredDocs, int nodeId) { //ルート以外のノード
if (scoredDocs.isEmpty())
throw new IllegalArgumentException();
this.source = source;
this.scoredDocs = scoredDocs;
this.nodeId = nodeId;
left = null;
right = null;
score = 0.0;
threshold = Double.NEGATIVE_INFINITY;
featureId = -1;
}
public void addNode(int feature, double threshold) {
assert(this.featureId == -1 && this.threshold == Double.NEGATIVE_INFINITY);
assert(left == null && right == null);
List<Document> leftDocs = new ArrayList<>();
List<Document> rightDocs = new ArrayList<>();
for (Document doc : this.scoredDocs) {
if (doc.getFeature(featureId) < threshold) leftDocs.add(doc);
else rightDocs.add(doc);
}
leftLeaf = new Node(this, leftDocs, 2 * leafId + 1);
rightLeaf = new Node(this, rightDocs, (2 * leafId) + 2);
this.featureId = feature;
this.threshold = threshold;
}
public double score(List<Double> features) {
if ((left == null || right == null))
return score;
Node destination = features.get(featureId) < threshold ? left : right;
return destination.score(features);
}
public List<Node> getLeaves() { //繰り返しにおける実装も可
List<Node> terminalLeaves = new ArrayList<>();
if (!this.hasChildren()) {
terminalLeaves.add(this);
return terminalLeaves;
}
getDestinations().forEach(leaf -> terminalLeaves.addAll(leaf.getTerminalLeaves()));
return terminalLeaves;
}
//必要なセッター、ゲッター
}
public class RegressionTree {
private final Node root;
private static void constructTree(Node root) {
// 以下で実装
}
public RegressionTree(int numLeaves, int initFeat, double initThreshold, List<Document> data) {
assert(numLeaves > 2); //ルート作成でリーフ2枚が作られる
root = Node.makeRoot(initFeat, initThreshold, data);
constructTree(root);
}
public double predict(List<Double> features) {
return root.score(features);
}
}
public class Ensemble {
private final List<RegressionTree> trees;
public Ensemble() {
trees = new ArrayList<>();
}
public double predict(List<Double> features) {
return trees.stream().mapToDouble(tree -> tree.predict(features)).sum();
}
}
本章では RegressionTree.constructTree() を実装します。 木の作成時、以下の関数を最小化する特徴量と閾値を選びます。
あるノードに子ノードが2個あるとします。左の子は L で 右の子 は R とします。S はある特徴量(特徴量 a と呼ぶ)で計算し、i は文書i のことを指します。ある閾値(t と呼ぶ)において、特徴量a が t 以下の文書は L 、そうでない文書は R に所属します。μ は L, または R の全文書のラベルの平均値です。
上記の関数の二つの和(Σ の項目)を計算するメッソドを作ります。
protected double calcLossSum(List<Document> subData){
if (subData.size() == 0)
return 0;
double avg = subData.stream().mapToDouble(doc -> doc.getLabel()).sum() / subData.size(); //μ
return subData.stream().mapToDouble(doc -> Math.pow(doc.getLabel() - avg, 2)).sum();
}
各特徴量において、S が最小値になる閾値を計算します。 以下、受け取る sortedData が事前に feat という特徴量で低い順でソートされていることを前提とします。 (ステップサイズで探索することで高速化可能)
private static double[] findThreshLoss(List<Document> sortedData, int feat) {
// 特徴量 feat の低い順でソートされた sortedData
if (sortedData.getMaxFeatureValue(feat) == sortedData.getMinFeatureValue(feat))
return new double[Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY];
int numDocs = sortedData.size();
double candidateThreshold = sortedData.get(0).getFeature(feat); // getMinFeatureValue(feat) とイコール
double minLoss = Double.POSITIVE_INFINITY;
for (int threshId = 0; threshId < numDocs; threshId++) {
if (threshId != numDocs -1 && fSortedDocs.get(threshId).getFeature(feat) == fSortedDocs.get(threshId + 1).getFeature(feat))
continue;
List<Document> lDocs = new ArrayList<>(fSortedDocs.subList(0, threshId + 1));
List<Document> rDocs = new ArrayList<>(fSortedDocs.subList(threshId + 1, numDocs));
double loss = calcLossSum(lDocs) + calcLossSum(rDocs);
if(loss < minLoss){
threshold = !rDocs.isEmpty() ? rDocs.get(0).getFeature(feat) : lDocs.get(lDocs.size() - 1).getFeature(feat);
minLoss = loss;
}
}
return new double[]{threshold, minLoss}
}
上記メソッドを用いて全特徴量において最適解を求めて、その中から S が最小の特徴量と閾値を選択します。 最適解に関する情報を格納する以下のクラスを戻り値とします。
public class OptimalFeatureThreshold {
private final int optimalFeature;
private final double minLoss;
private final double optimalThreshold;
public OptimalFeatureThreshold( int optimalFeature, double optimalThreshold, double minLoss){
this.optimalFeature = optimalFeature;
this.optimalThreshold = optimalThreshold;
this.minLoss = minLoss;
}
public int getOptimalFeature() { return optimalFeature; }
public double getMinLoss() { return minLoss; }
public double getOptimalThreshold() { return optimalThreshold; }
}
public static OptimalFeatureThreshold findOptimalFeatThresh(List<List<Document>> sortedDataList) {
int optimalFeat = 0;
List<Document> sortedData = sortedDataList.get(0); //特徴量 0 の低い順でソートされた文書のリスト
double[] optimalThreshLoss = findThreshLoss(sortedData);
for (int currentFeat = 1; currentFeat < sortedDataList.size(); currentFeat++) {
sortedData = sortedDataList.get(currentFeat);
double[] currentThreshLoss = findThreshLoss(sortedDataList);
if (currentThreshLoss[1] < optimalThreshLoss[1]) {
optimalFeat = currentFeat;
optimalThreshLoss = currentThreshLoss;
}
}
return new OptimalFeatureThreshold(optimalFeat, optimalThreshLoss[0], optimalThreshLoss[1]);
}
上記の findOptimalFeatThresh を用いて、全文書を左、もしくは右の子ノードへ移動させ、左右でそれぞれの文書リストを使い、指定されたリーフ数まで繰り返すと回帰木が完成します。まとめると、RegressionTree.constructTree() は以下のように実装できます。
public class RegressionTree {
//... 省略
private Node findOptimalLeaf(Map<Node, OptimalFeatureThreshold> nodeThresholdMap) {
Iterator<Map.Entry<Node, OptimalFeatureThreshold>> iterator = nodeThresholdMap.entrySet().iterator();
Map.Entry<Node, OptimalFeatureThreshold> optimalEntry = iterator.next();
while(iterator.hasNext()){
Map.Entry<Node, OptimalFeatureThreshold> nextEntry = iterator.next();
if(nextEntry.getValue().getMinLoss() < optimalEntry.getValue().getMinLoss())
optimalEntry = nextEntry;
}
return optimalEntry.getKey();
}
private static void constructTree(Node root) {
Map<Node, OptimalFeatureThreshold> nodeErrorMap = new HashMap<>();
for(int l = 2; l < numLeaves; l++) {
for (Node leaf : root.getTerminalLeaves()) {
if(!nodeErrorMap.containsKey(leaf)) //ダイナミックにスプリット(子ノードを追加)するノードを見つける
nodeErrorMap.put(leaf, treeTools.findMinLeafThreshold(leaf.getScoredDocs(), numSteps));
}
Node optimalLeaf = findOptimalLeaf(nodeErrorMap);
int feature = nodeErrorMap.get(optimalLeaf).getOptimalFeature();
double threshold = nodeErrorMap.get(optimalLeaf).getOptimalThreshold();
optimalLeaf.addNode(feature, threshold);
nodeErrorMap.remove(optimalLeaf);
}
}
//...省略
}
回帰木のスコア計算
一回のイタレーションで回帰木の作成が終わったら、次に今まで作ってきた回帰木のアンサンブルを用いて学習し、新しくできた回帰木のスコアをセットします。LambdaMART の論文では、スコア計算のための以下の数式があります。
上記の γkm
は m 番目の木の k 番目のリーフのスコアとなります。
上記の ∂C / ∂si
は ρ
で表してありますが、Λに切り替えましょう。
σ
は 1 とします。
本記事では ΔZ
を ΔNDCG
とします。NDCG
の方程式は以下の通りです。
NDCG の計算を実装します。
public class NDCG {
public static double dcg(List<Document> docRanks, int position) {
//docRanks はランキング順・表示順でソートされていること前提とします。
double sum = 0;
if (position > -1) {
final int pos = Math.min(position, docRanks.size());
for (int i = 0; i < pos; i++) {
sum += (Math.pow(2, docRanks.get(i).getLabel() - 1)) / Math.log(i + 2);
}
}
return sum * Math.log(2); //底の変換公式
}
//Ideal Discounted Cumulative Gain (理想のDCG)
public static double idcg(List<Document> docList, int position) {
List<Document> docsRanks = new ArrayList<>(docList);
docsRanks.sort(Comparator.comparingInt(Document::getLabel).reversed());
return dcg(docList, position);
}
//Normalized Discounted Cumulative Gain (正規化された DCG)
public static double ndcg(List<Document> docList, int position) {
return dcg(docList, position) / idcg(docList, position);
}
}
毎回 Math.pow(2, docRanks.get(i).getLabel() - 1)
や1 / Math.log(i + 2)
は少し効率が良くないので、簡単に以下のように実装します(もちろんヘルパークラスによる実装は可能):
public class LambdaMartTrainer {
//コンストラクタで全特徴量において findOptimalFeatThresh()を実行し以下
//の thresholds を初期化する
private final double[][] thresholds; // {{最適閾値0, 最小損失0},
// {最適閾値1}, 最小損失1}...}
private static double sigmoid(double input) {
return 1 / (1 + Math.exp(-input));
}
public static int findMinLossFeat(double[][] thresholds, double minLoss){
//thresholds から初期特徴量・閾値を探る
int feat = -1;
double loss = Double.POSITIVE_INFINITY;
for (int fid = 0; fid < thresholds.length; fid++){
if(thresholds[fid][1] < loss && thresholds[fid][1] > minLoss){
loss = thresholds[fid][1];
feat = fid;
}
}
return feat;
}
public void train() {
// (yi - 1) * (yi - 1)
//y = 正解・ラベル idcg 計算用
Map<Document, Double> pws = new HashMap<>();
double minLoss = 0;
for(Document doc : trainingDocs) pws.put(doc,Math.pow(2, doc.getLabel()) - 1 );
HashMap<Document, Double> ranks = new HashMap<>(); // si, sj etc... (今まで作ってきたアンサンブルの出力)
HashMap<Document, Double> lambdas = new HashMap<>(); //Λi, Λj etc...
HashMap<Document, Double> logs = new HashMap<>(); // 1 / (log (ランクi - 1)); あるクエリーの全ての s 従属
HashMap<Document, Double> lambdaDers = new HashMap<>(); //∂Λ / ∂s
for (int t = 1; t <= numTrees; t++){
int minLossFeat = findMinLossFeat(thresholds, minLoss);
if(minLossFeat == -1){
System.out.printf("Stopping early at tree %d \n", t);
return;
}
//今回のイタレーションにおけるΛ計算
//trainingPairs は、一つのクエリーに対してラベルが異なる文書ペアのリスト。
//全文書のラベルが同様の場合、null
for (int queryId = 0; queryId < trainingSet.size(); queryId++) {
if (trainingPairs.get(queryId) == null) //スキップ;
continue;
Query query = trainingSet.get(queryId); //trainingSet は訓練データセット
double N = NDCG.idcg(query.getDocList(), query.getDocList().size());
List<Document> sorted = ensemble.sort(query);
for (int i = 0; i < sorted.size(); i++) { // Dynamic Programming 用
Document doc = sorted.get(i);
ranks.put(doc, ensemble.predict(doc.getFeatures()));
lambdas.put(doc, 0d);
lambdaDers.put(doc, 0d);
logs.put(doc, 1 / Math.log(i + 2));
}
for (Document[] pair : trainingPairs.get(queryId)) {
//ΔNDCG
double dNCG = (pws.get(pair[0]) - pws.get(pair[1])) * (logs.get(pair[0]) - logs.get(pair[1])) / N;
double diff = ranks.get(pair[1]) - ranks.get(pair[0]); //- (si - sj) ; sigmoid has minus sign
double lambda = sigmoid.output(diff) * dNCG;
double lambdaDer = lambda * (1 - (lambda/dNCG));
lambdas.put(pair[0], lambdas.get(pair[0]) - lambda); //λ1 = λ1 - dλ
lambdas.put(pair[1], lambdas.get(pair[1]) + lambda); //λ2 = λ2 - dλ
lambdaDers.put(pair[0], lambdaDers.get(pair[0]) - lambdaDer);
lambdaDers.put(pair[1], lambdaDers.get(pair[1]) + lambdaDer);
}
}
//回帰木の作成
double[] minThresholdLoss = thresholds[minLossFeat]; //minLossFeat という特徴量で事前に計算された
double minThreshold = minThresholdLoss[0]; //閾値と対応する損失を使って回帰木を作る
RegressionTree tree = new RegressionTree(numLeaves, minLossFeat, minThreshold, trainingDocs, numSteps);
ensemble.addTree(tree);
minLoss = minThresholdLoss[1]; //次回別の閾値・特徴量を選択
List<Node> terminalLeaves = tree.getTerminalLeaves();
//Λ、Λの微分を使って、回帰木のリーフにスコアをセット
for(Node leaf : terminalLeaves){
double y = leaf.getScoredDocs().stream().filter(doc -> lambdas.containsKey(doc)).mapToDouble(doc -> lambdas.get(doc)).sum();
double w = leaf.getScoredDocs().stream().filter(doc -> lambdaDers.containsKey(doc)).mapToDouble(doc -> lambdaDers.get(doc)).sum();
leaf.setScore(y / w); //学習率でかけることが可能
}
validate(t, evalK);
}
}
}
上記のコードでは学習が行えます。全く同じ回帰木を作らないように minThresholdLoss などで 毎回異なる初期特徴量・閾値を選択するという実装になっていますが、もちろん実装方法は様々あります。 以上、本記事では LambdaMART の回帰木作成と回帰木のリーフ計算に関わる数式をコードにしました。