Keras の出力で異なる活性化関数を使いそれぞれの損失関数を指定する

画像を認識する機械学習のプログラムを勉強しています。Keras を使って、次のことが学習できるかためしてみました。

  • 画像中に物体が存在するかどうか。
  • 物体が存在するならば、その物体を囲む長方形の位置。

このモデルで、出力に異なる 2 種類の 活性化関数 を使い、それぞれに 損失関数 を指定してみました。

作成したプログラムは、以下のところにあります。

このプログラムは、以下のプログラムを参考にしました。

入力は、プログラムで乱数を使って作成した 8x8 の画像を使います。この入力を 200 ニューロンDense で処理します。出力は、次のようにします。

  • 物体が存在するかどうかは、活性化関数 として シグモイド関数 を使います。損失関数binary_crossentropy を使います。
  • 物体を囲む長方形の位置は、左上の x 位置と y 位置、幅、高さの 4 つの値で表します。これらの 活性化関数 は線形結合にします。損失関数 は、以下のようにします。
    • 物体が存在するならば、mean_absolute_error
    • 物体が存在しないならば、0

モデルは、Function API を使って作成します。

# Build the model.
from keras.models import Model
from keras.layers import Input, Dense, Dropout
# from keras.optimizers import SGD

inputs = Input(shape=(X.shape[-1], ), name='Input')
x = Dense(200, activation='relu', name='Dense_1')(inputs)
x = Dropout(0.2, name='Dropout_1')(x)
exists = Dense(1, activation='sigmoid', name='exists')(x)
bbox = Dense(4, activation='linear', name='bbox')(x)

model = Model(inputs=inputs, outputs=[exists, bbox])
model.compile(
    'adadelta', 
    loss={'exists': exists_loss, 'bbox': bboxes_loss})
    # loss={'exists': 'binary_crossentropy', 'bbox': 'mean_absolute_error'})
model.summary()

このモデルの出力は、以下の 2 つを組み合わせたものです。

それぞれの出力の損失関数は、exists_loss と bboxes_loss になります。これらの定義は次のようになります。

from keras.losses import binary_crossentropy, mean_absolute_error
import tensorflow as tf

y_loss_true_exists = None
    
def exists_loss(y_true, y_pred):
    global y_loss_true_exists
    y_loss_true_exists = y_true[ : , 0]
    return binary_crossentropy(y_true, y_pred)

def bboxes_loss(y_true, y_pred):
    global y_loss_true_exists
    return y_loss_true_exists * mean_absolute_error(y_true, y_pred)

損失関数の引数 y_true, y_pred ですが、それぞれの出力に対応する値が渡されます。exists_loss には exists の正解と予測値、bboxes_loss には bbox の正解と予測値が渡されます。bboxes_loss で exists の正解の値を使いたいのですが、これは渡されません。

この問題を解決するため、exists_loss で exists の正解をグローバル変数で保存しておくことにします:-( 損失関数の呼ばれる順番ですが、Model の output で指定した出力の順番になります。そのため、exists_loss がまず呼び出され、その次に bboxes_loss が呼ばれます。

bboxes_loss は、物体が存在するかどうかで結果を変える必要があります。このために if 文で場合分けするのではなく、それぞれのデータで物体が存在する場合は 1、存在しない場合は 0 になるようにして、そのベクトルを mean_absolute_error の値のベクトルに掛けるようにしました。

損失関数に渡される y_true, y_pred ですが、最初の次元はバッチです。トレーニングしているバッチのデータが、順にまず並んでいます。その次の次元に実際のデータが入ります。exists_loss では、2 番目の次元の最初の値が物体が存在するかどうかを示す値になります。この値を bbox の予測と正解から計算した mean_absolute_error の値に掛けます。そうすると、物体が存在するときの値はそのまま残り、物体が存在しないときは 0 になります。