May 2016
Volume 31 Number 5
テストの実行 - 多腕バンディット問題
ラスベガスで 3 台のスロット マシンの前に立っているとします。手持ちの硬貨は 20 枚、3 台のうちいずれかのスロット マシンに硬貨を投入してハンドルを引けば、何某かの金額が払い戻されます。スロット マシンの払い戻しはそれぞれ決まっておらず、最初はどのようなスケジュールで払い戻しが行われるかわかりません。どうすれば、最高額の払い戻しを受けられるでしょう。
これがいわゆる「多腕バンディット問題」の例で、スロット マシンの俗称が「one-armed bandit (片腕バンディット)」だったことに由来しています。多腕バンディット問題は見かけほど風変わりなものではありません。スロット マシンの例に似た重要な問題は、薬の治験など身近にたくさんあります。
企業の開発シナリオで、多腕バンディット問題を実装するよう求められた開発者はあまりいないでしょう。ですが、このコラムを読んでみようとお考えになった理由なら 3 つほど思いつきます。まず、このコラムで取り上げるプログラミング手法には、他の一般的なプログラミング シナリオに流用できるものがあるためです。次に、多腕バンディット問題の具体的なコード実装は、経済が活発な分野や機械学習研究への入門書として利用できるためです。最後に、多腕バンディット問題のトピック自体に興味をそそられたためかもしれません。
このコラムの目的を理解するには、図 1 のデモ プログラムを見るのが一番です。多腕バンディット問題には、さまざまなアルゴリズムを使用できます。たとえば、完全にランダムなアプローチとして、毎回スロット マシンを無作為に 1 台選んでハンドルを引き、最高のスロット マシンを探すやり方です。今回紹介するデモでは、「探索と活用アルゴリズム」という基本的な手法を使用しています。
図 1. 多腕バンディット問題の探索と活用アルゴリズムの使用
デモでは、まず 3 台のスロット マシンを作成します。各スロット マシンはハンドルを引くたびにランダムな金額を払い戻します。この払い戻しは指定された平均値と標準偏差を使ったガウス分布 (鐘状分布) に従います。操作の平均払い戻し金額 (0.1 単位) が最高額になっている点から、3 台目のスロット マシンが狙い目です。実際のシナリオでは、スロット マシンの払い戻し特性はわかりません。
手持ちの硬貨の枚数から操作できる総回数は 20 に設定しています。探索と活用アルゴリズムでは、この総回数からある一部分を取り出し、そこから最高のスロット マシンを見つけ出します。その後の操作では、この事前探索で見つかった最高のスロット マシンだけを使用します。探索と活用アルゴリズムの重要な変数は、探索段階に使用する操作回数の割合です。今回のデモでは探索段階の割合を 0.40 に設定しています。したがって、20 * 0.40 = 8 回を探索段階に、20 - 8 = 12 回を活用段階に振り分けます。探索段階の割合を増やすと、最高のスロット マシンが見つかる確率は高まりますが、活用段階で最高のスロット マシンを利用する操作回数が減少します。
探索段階の 8 回の操作に対して、ランダムに選択したスロット マシンと、関連する払い戻し金額を表示します。その間、各スロット マシンの合計払い戻し金額も保存しています。スロット マシン 0 は 3 回選択され、-0.09 + 0.12 + 0.29 = +0.32 単位の払い戻しを行っています。スロット マシン 1 は 2 回選択され、-0.46 + -1.91 = -2.37 単位の払い戻しを行っています。スロット マシン 2 は 3 回選択され、0.19 + 0.14 + 0.70 = +1.03 単位の払い戻しを行っています。このデモの場合、スロット マシン 2 の払い戻し金額が最高になるため、探索と活用アルゴリズムはスロット マシン 2 を最高のスロット マシンとして正しく特定します。この時点で、アルゴリズムの総獲得 (損失) 額は 0.32 + -2.37 + 1.03 = -1.02 単位です。
活用段階の 12 回の操作には、スロット マシン 2 だけを繰り返し操作します。その結果の払い戻し金額は 0.03 + 0.33 + . . + 0.45 = +2.32 単位です。したがって、全 20 回の払い戻し総額は -1.02 + 2.32 = +1.30 単位になり、1 回の平均払い戻し金額は 1.30 / 20 = 0.065 になります。
多腕バンディット アルゴリズムの有効性評価に使用できる測定基準は複数あります。一般的な尺度の 1 つはリグレットと呼ばれます。リグレットとは、基準となる理論上の払い戻し総額と、アルゴリズムが求めた払い戻し総額の差のことです。基準となる理論上の払い戻し金額とは、定められた操作回数をすべて最高のスロット マシンで実行した場合に期待される払い戻し金額です。デモの 3 台のスロット マシンのうち、最高のスロット マシンの平均払い戻し金額は 0.10 単位なので、最高のスロット マシンで 20 回の操作をすべて実行した場合に期待される払い戻し額は 20 * 0.10 = 2.00 単位です。探索と活用アルゴリズムの払い戻し総額は 1.30 単位にすぎないため、リグレットは 2.00 - 1.30 = 0.70 単位です。アルゴリズムとしては、リグレット値が低いほど優れています。
今回は、少なくとも中級レベルのプログラミング スキルがあることを前提としますが、多腕バンディット問題の知識は問いません。スペースを節約するために少し編集したデモ プログラムを 図 2 に示します。また、付属のコード ダウンロードから入手することもできます。デモ は C# を使用してコーディングしていますが、デモを Python や Java などの別の言語にリファクタリングしても大きな問題は起きません。また、多腕バンディット問題のメインとなる考え方が明確になるように、デモから通常のエラー チェックをすべて省略しました。
図 2. 多腕バンディットのデモ コード一式
using System;
namespace MultiBandit
{
class MultiBanditProgram
{
static void Main(string[] args)
{
Console.WriteLine("\nBegin multi-armed bandit demo \n");
Console.WriteLine("Creating 3 Gaussian machines");
Console.WriteLine("Machine 0 mean = 0.0, sd = 1.0");
Console.WriteLine("Machine 1 mean = -0.5, sd = 2.0");
Console.WriteLine("Machine 2 mean = 0.1, sd = 0.5");
Console.WriteLine("Best machine is [2] mean pay = 0.1");
int nMachines = 3;
Machine[] machines = new Machine[nMachines];
machines[0] = new Machine(0.0, 1.0, 0);
machines[1] = new Machine(-0.5, 2.0, 1);
machines[2] = new Machine(0.1, 0.5, 2);
int nPulls = 20;
double pctExplore = 0.40;
Console.WriteLine("Setting nPulls = " + nPulls);
Console.WriteLine("\nUsing pctExplore = " +
pctExplore.ToString("F2"));
double avgPay = ExploreExploit(machines, pctExplore,
nPulls);
double totPay = avgPay * nPulls;
Console.WriteLine("\nAverage pay per pull = " +
avgPay.ToString("F2"));
Console.WriteLine("Total payout = " +
totPay.ToString("F2"));
double avgBase = machines[2].mean;
double totBase = avgBase * nPulls;
Console.WriteLine("\nBaseline average pay = " +
avgBase.ToString("F2"));
Console.WriteLine("Total baseline pay = " +
totBase.ToString("F2"));
double regret = totBase - totPay;
Console.WriteLine("\nTotal regret = " +
regret.ToString("F2"));
Console.WriteLine("\nEnd bandit demo \n");
Console.ReadLine();
} // Main
static double ExploreExploit(Machine[] machines,
double pctExplore, int nPulls)
{
// Use basic explore-exploit algorithm
// Return the average pay per pull
int nMachines = machines.Length;
Random r = new Random(2); // which machine
double[] explorePays = new double[nMachines];
double totPay = 0.0;
int nExplore = (int)(nPulls * pctExplore);
int nExploit = nPulls - nExplore;
Console.WriteLine("\nStart explore phase");
for (int pull = 0; pull < nExplore; ++pull)
{
int m = r.Next(0, nMachines); // pick a machine
double pay = machines[m].Pay(); // play
Console.Write("[" + pull.ToString().PadLeft(3) + "] ");
Console.WriteLine("selected machine " + m + ". pay = " +
pay.ToString("F2").PadLeft(6));
explorePays[m] += pay; // update
totPay += pay;
} // Explore
int bestMach = BestIdx(explorePays);
Console.WriteLine("\nBest machine found = " + bestMach);
Console.WriteLine("\nStart exploit phase");
for (int pull = 0; pull < nExploit; ++pull)
{
double pay = machines[bestMach].Pay();
Console.Write("[" + pull.ToString().PadLeft(3) + "] ");
Console.WriteLine("pay = " +
pay.ToString("F2").PadLeft(6));
totPay += pay; // accumulate
} // Exploit
return totPay / nPulls; // avg payout per pull
} // ExploreExploit
static int BestIdx(double[] pays)
{
// Index of array with largest value
int result = 0;
double maxVal = pays[0];
for (int i = 0; i < pays.Length; ++i)
{
if (pays[i] > maxVal)
{
result = i;
maxVal = pays[i];
}
}
return result;
}
} // Program class
public class Machine
{
public double mean; // Avg payout per pull
public double sd; // Variability about the mean
private Gaussian g; // Payout generator
public Machine(double mean, double sd, int seed)
{
this.mean = mean;
this.sd = sd;
this.g = new Gaussian(mean, sd, seed);
}
public double Pay()
{
return this.g.Next();
}
// -----
private class Gaussian
{
private Random r;
private double mean;
private double sd;
public Gaussian(double mean, double sd, int seed)
{
this.r = new Random(seed);
this.mean = mean;
this.sd = sd;
}
public double Next()
{
double u1 = r.NextDouble();
double u2 = r.NextDouble();
double left = Math.Cos(2.0 * Math.PI * u1);
double right = Math.Sqrt(-2.0 * Math.Log(u2));
double z = left * right;
return this.mean + (z * this.sd);
}
}
// -----
} // Machine
} // ns
デモ プログラムを作成するには、Visual Studio を起動して、C# コンソール アプリケーション テンプレートを選択します。プロジェクトには「MultiBandit」という名前を付けます。今回は Visual Studio 2015 を使用しましたが、このデモは .NET のバージョンにあまり依存しないため、どのバージョンの Visual Studio でも機能します。
テンプレート コードが読み込まれたら、ソリューション エクスプローラー ウィンドウで Program.cs ファイルを右クリックし、名前を「MultiBanditProgram.cs」というわかりやすい名前に変更します。Visual Studio によってクラスの名前が自動的に「MultiBandit」に変更されます。エディター ウィンドウのコード上部にある不要な using ステートメントを、最上位レベルの System 名前空間を参照するステートメントを除いてすべて削除します。
すべての制御ロジックは、ExploreExploit メソッドを呼び出す Main メソッドにあります。デモには、プログラム定義の Machine クラスがあり、プログラム定義が入れ子になった Gaussian クラスもあります。
冒頭で WriteLine ステートメントを表示した後、3 台のスロット マシンを作成します。
int nMachines = 3;
Machine[] machines = new Machine[nMachines];
machines[0] = new Machine(0.0, 1.0, 0);
machines[1] = new Machine(-0.5, 2.0, 1);
machines[2] = new Machine(0.1, 0.5, 2);
Machine クラスのコンストラクターは、平均払い戻し金額、払い戻し金額の標準偏差、および乱数生成に使用するシード値の 3 つの引数を受け取ります。つまり、machine [1] は毎回の操作で平均 -0.5 単位を払い戻し、ほとんど (約 68%) の払い戻し金額は -0.5 - 2.0 = -2.5 単位と -0.5 + 2.0 = +1.5 単位の間になります。0 または正の金額を払い戻す実際のスロット マシンと異なり、デモ マシンは負の金額も払い戻しできます。
3 台のスロット マシンで探索と活用アルゴリズムを実行するステートメントを以下に示します。
int nPulls = 20;
double pctExplore = 0.40;
double avgPay = ExploreExploit(machines, pctExplore, nPulls);
double totPay = avgPay * nPulls;
ExploreExploit メソッドは、無作為に選んだ nPulls 回の操作を行った後、1 回の操作あたりの平均獲得 (損失) 金額を返します。したがって、そのセッションの払い戻し総額は、操作回数に平均払い戻し金額を乗算した金額になります。ExploreExploit の別の設計として、平均払い戻し金額ではなく払い戻し総額を返してもかまいません。
リグレットは以下のように計算します。
double avgBase = machines[2].mean;
double totBase = avgBase * nPulls;
double regret = totBase - totPay;
avgBase 変数は、最高のスロット マシンの平均払い戻し金額で、machine [2] = 0.1 単位です。したがって、20 回の操作後に期待される平均払い戻し総額は、20 * 0.10 = 2.0 単位になります。
デモの各スロット マシンは、ガウス分布 (別称、正規分布または鐘状分布) に従って金額を払い戻します。たとえば、machine [0] には、平均払い戻し金額 0.0 単位と、標準偏差 1.0 単位が設定されています。ガウス値を生成するデモ コードを使用して、machine [0] から 100 回のランダムな払い戻しを行う短いプログラムを作成しました。結果を 図 3 のグラフに示します。
図 3. 100 個のランダム ガウス値
生成された値の大部分は、平均に近い値になっています。生成値のばらつきは、標準偏差の値によって制御されます。標準偏差の値が大きくなるにつれて、値のばらつきも大きくなります。多腕バンディット問題において、すべてのアルゴリズムで最も重要な因子の 1 つは、スロット マシンの払い戻し金額のばらつきです。スロット マシンの払い戻し金額のばらつきが大きいと、本当の平均払い戻し金額を評価するのが非常に難しくなります。
指定された平均値と標準偏差を使用するガウス分布の乱数値を生成するのに使用できるアルゴリズムは複数あります。お勧めは、「ボックス ミュラー」アルゴリズムです。このアルゴリズムでは、まず、一様分散値 (.NET Math.Random クラスによって生成される値) を生成後、あるとても賢い計算を使用して、この一様分散値をガウス分布値に変換します。このアルゴリズムには、いくつかのバリエーションがあります。デモ プログラムでは、他のバリエーションと比べてあまり効率的ではありませんが、非常にシンプルなバリエーションを使用しています。
デモ プログラムでは、Gaussian クラスを Machine クラス内で定義しています。Microsoft .NET Framework では、入れ子になったクラス定義は主に、入れ子になったクラスが、外側のコンテナー クラスが使用するユーティリティ クラスのときに使用します。このデモ コードを .NET 以外の言語に移植する場合は、Gaussian クラスをリファクタリングしてスタンドアロン クラスにすることをお勧めします。Gaussian クラスには、平均払い戻し金額、払い戻しの標準偏差、および基盤となる均一乱数ジェネレーターのシード値を受け取るコンストラクターが 1 つだけあります。
デモ プログラムでは、非常に簡単な方法で Machine クラスを定義しています。以下の 3 つのクラス フィールドがあります。
public class Machine
{
public double mean; // Avg payout per pull
public double sd; // Variability about the mean
private Gaussian g; // Payout generator
...
Machine クラスは主に、ガウス乱数ジェネレーターを囲むラッパーです。さまざまな設計が考えられますが、一般には、クラスの定義をできるだけシンプルにするのがお勧めです。研究資料の中には、今回示したような標準偏差ではなく、数学的分散が使用されているものもあります。分散は標準偏差のちょうど 2 乗になるため、標準偏差でも分散でも同じです。
Machine クラスには、Gaussian ジェネレーターをセットアップするコンストラクターが 1 つあります。
public Machine(double mean, double sd, int seed)
{
this.mean = mean;
this.sd = sd;
this.g = new Gaussian(mean, sd, seed);
}
Machine クラスには、ガウス分散のランダムな払い戻し金額を返すパブリック メソッドが 1 つあります。
public double Pay()
{
return this.g.Next();
}
ガウス分散の払い戻し金額を返す代わりに、指定されたエンドポイント間で一様分散する値を返してもかまいません。たとえば、スロット マシンから -2.0 ~ + 3.0 のランダム値を返して、平均払い戻し金額が (-2 + 3) / 2 = +0.5 単位になるようにすることも考えられます。
ExploreExploit メソッドの定義は以下のように始めます。
static double ExploreExploit(Machine[] machines, double pctExplore,
int nPulls)
{
int nMachines = machines.Length;
Random r = new Random(2); // Which machine
double[] explorePays = new double[nMachines];
double totPay = 0.0;
...
Random オブジェクト r は、探索段階でスロット マシンを無作為に選択するために使用します。explorePays という配列は、探索段階での各スロット マシンの払い戻し金額を累積するために使用します。活用段階ではスロット マシンを 1 台しか使用しないので、払い戻し総額を保持するのに必要なのは変数 totPay の 1 つだけです。
次に、探索段階と活用段階の操作回数を計算します。
int nExplore = (int)(nPulls * pctExplore);
int nExploit = nPulls - nExplore;
項 (1.0 - pctExplore) を使って活用段階の操作回数を求めると、丸めの誤差によって切り捨てが行われ、間違った回数になる可能性があります。
探索段階のコードを以下に示します。WriteLine ステートメントは省略しています。
for (int pull = 0; pull < nExplore; ++pull)
{
int m = r.Next(0, nMachines); // Pick a machine
double pay = machines[m].Pay(); // Play
explorePays[m] += pay; // Update
totPay += pay;
}
Random.Next(int minVal, int maxVal) は、minVal (inclusive) と maxVal (exclusive) との間の整数値を返すため、nMachines = 3 の場合、r.Next(0, nMachines) はランダムな整数値 0、1 または 2 を返します。
次に、探索段階中に見つけた最高のスロット マシンを決め、活用段階で使用します。
int bestMach = BestIdx(explorePays);
for (int pull = 0; pull < nExploit; ++pull)
{
double pay = machines[bestMach].Pay();
totPay += pay; // Accumulate
}
プログラム定義のヘルパー メソッド BestIdx は、配列引数の中の最大値を保持するセルのインデックスを返します。多腕バンディット問題のバリエーションは数多く存在します。たとえば、探索段階に異なる方法を使って最高のスロット マシンを定義するバリエーションもあります。個人的見解ですが、このようなバリエーションの多くは、研究課題を探求するための解決策にすぎません。
ExploreExploit メソッドは、nPulls 回の操作の平均払い戻し金額を計算後、結果を返して終了です。
. . .
return totPay / nPulls;
}
他にも、平均払い戻し金額の代わりに払い戻し総額や総リグレット値を返す設計、払い戻し総額と平均払い戻し金額を 2 要素のセル配列や 2 つの出力パラメーターとして返す設計も考えられます。
研究では、すべての種類の多腕バンディット問題に最適な 1 つのアルゴリズムは存在しないとされています。アルゴリズムごとにそれぞれ長所と短所があり、問題におけるスロット マシンの台数、実行可能な操作回数、払い戻し分布関数の種類に大きく左右されます。
Dr. James McCaffrey は、ワシントン州レドモンドにある Microsoft Research に勤務しています。これまでに、Internet Explorer、Bing などの複数のマイクロソフト製品にも携わってきました。McCaffrey 博士の連絡先は、jammc@microsoft.com (英語のみ) です。
この記事のレビューに協力してくれたマイクロソフト技術スタッフの Miro Dudik および Kirk Olynyk に心より感謝いたします。