MATLAB の bsxfun に親しむ
1 月から 3 月にかけて Coursera の機械学習コースを受講しましたが、毎週の課題として提出したプログラムの中で、MATLAB の bsxfun 関数を利用する機会が何度かありました。bsxfun 関数を利用すると、ある種の処理をループに頼らず効率的に記述できます。今回は簡単な例を用いて bsxfun 関数の使い方を見てみます。
bsxfun というのは Binary Singleton eXpansion FUNction の略で、大きさの異なる二つの多次元配列に対して、要素数が 1 である次元を拡張*1しながら、要素ごとの二項演算を適用するものです。マニュアルの以下のページに説明があります。
大きさが 1 の次元の拡張を有効にして 2 つの配列に要素ごとのバイナリ演算を適用 - MATLAB bsxfun - MathWorks 日本
簡単な例として、bsxfun を用いて掛け算の九九を計算してみます*2。ここでは、9 行 1 列の配列 x' と 1 行 9 列の配列 x に要素ごとの掛け算 @times を適用して 9 行 9 列の行列を得ています。
>> x = 1:9 x = 1 2 3 4 5 6 7 8 9 >> bsxfun(@times, x', x) ans = 1 2 3 4 5 6 7 8 9 2 4 6 8 10 12 14 16 18 3 6 9 12 15 18 21 24 27 4 8 12 16 20 24 28 32 36 5 10 15 20 25 30 35 40 45 6 12 18 24 30 36 42 48 54 7 14 21 28 35 42 49 56 63 8 16 24 32 40 48 56 64 72 9 18 27 36 45 54 63 72 81
bsxfun 関数による計算の様子を下図に示します。青色の箇所が x' を、橙色の箇所が x を表します。それぞれの配列を相手の大きさに合わせて拡張したうえで、要素ごとの掛け算が適用されます。二つの配列の対応する各次元は同じ大きさであるか 1 でなければいけません。なお、bsxfun 関数による配列の「拡張」は実際に配列が複製されるわけではなく、repmat 関数を用いて同等の計算を行うよりも効率的です。
bsxfun の処理では、各要素に個別に関数が適用されるとは限らない点に注意が必要です。つまり、上記の九九の例では @times が 81 回適用されるわけではありません。このことを具体的に確認してみます。以下のように mytimes 関数を定義します。この関数は、引数 a, b の大きさを画面に出力してから要素ごとの掛け算を実行します。
function c = mytimes(a, b) fprintf('size(a) = [%d %d], size(b) = [%d %d]\n', size(a), size(b)); c = a .* b; end
mytimes 関数を使って九九を計算させてみると、次の出力が得られます*3。x' の方は 9 行 1 列の配列のまま mytimes 関数に渡されており、mytimes は 9 回実行されていることがわかります。
>> bsxfun(@mytimes, x', x); size(a) = [9 1], size(b) = [1 1] size(a) = [9 1], size(b) = [1 1] size(a) = [9 1], size(b) = [1 1] size(a) = [9 1], size(b) = [1 1] size(a) = [9 1], size(b) = [1 1] size(a) = [9 1], size(b) = [1 1] size(a) = [9 1], size(b) = [1 1] size(a) = [9 1], size(b) = [1 1] size(a) = [9 1], size(b) = [1 1]
bsxfun による計算はこのように実行されるため、上記の mytimes のように自作の関数を用いる場合には、その引数がスカラーであることを前提にしてはいけません。マニュアルにも以下のように記載されているとおりです。
fun はスカラー拡張もサポートしなければなりません。たとえば、A または B がスカラーの場合、C は他の入力配列内のすべての要素にスカラーを適用した結果となるようにしなければなりません。
以下、このことを考慮しない実装による失敗例を見てみます。平面上の 4 点 p1, p2, p3, p4 について、各点間の距離を 4 行 4 列の行列に求める計算を考えます。次のように 4 点を定義して、全体を行列 P にまとめておきます。
>> p1 = [0 0]; >> p2 = [1 0]; >> p3 = [1 1]; >> p4 = [2 2]; >> P = [p1; p2; p3; p4] P = 0 0 1 0 1 1 2 2
先に答えを確認してしまうと、これらの各点間の距離は MATLAB の関数を利用すれば次のように計算できます。
>> squareform(pdist(P)) ans = 0 1.0000 1.4142 2.8284 1.0000 0 1.0000 2.2361 1.4142 1.0000 0 1.4142 2.8284 2.2361 1.4142 0
これを bsxfun で計算することを考えます。まず、インデックス i, j を受け取り、P(i,:) と P(j,:) の距離を求める関数を定義します。以下のように正しく距離を計算できます。
>> mydist = @(i, j) sqrt(sum((P(i,:) - P(j,:)) .^ 2, 2)); >> mydist(1, 1) % 0 >> mydist(1, 2) % 1 >> mydist(1, 3) % 1.4142 >> mydist(1, 4) % 2.8284
ところが、bsxfun を用いて mydist を (1:4)', 1:4 に適用する方法は上手くいきません。実行すると次のようなエラーメッセージが出力されます。
>> bsxfun(mydist, (1:4)', 1:4) エラー: - 行列の次元は一致しなければなりません。 エラー: @(i,j)sqrt(sum((P(i,:)-P(j,:)).^2))
mytimes 関数で確認したように、i, j は必ずしもスカラーになるとは限りません。この例でも i には (1:4)' がそのまま渡され、行列 P 全体から P の第 j 行を引こうとして次元が合わないため、前述のようなエラーとなります*4。