猫でもわかるWeb開発・プログラミング

本業エンジニアリングマネージャー。副業Webエンジニア。Web開発のヒントや、副業、日常生活のことを書きます。

Tensorflow で Model.fit_generator is deprecated Please use Model.fit, which supports generators.

f:id:yoshiki_utakata:20210223112518p:plain

Tensorflow で Model.fit_generator を使ったら下記Warningが出た

WARNING:tensorflow:From <ipython-input-12-4f61d48d2ed6>:1: 
Model.fit_generator (from tensorflow.python.keras.engine.training) 
is deprecated and will be removed in a future version.
Instructions for updating:
Please use Model.fit, which supports generators.

該当のコードはこんな感じ

from tensorflow.keras.utils import Sequence

# 自前の Sequence を定義し
class MySequence(Sequence):
    ....

# モデルを構築
model = Sequential()
...(省略)

# 学習
# ここで Warning が出る
model.fit_generator(
    generator=MySequence(), 
    epochs=5, 
    verbose=1
)

fit メソッドも Sequence には対応しているので、 fit_generator を fit に置き換える。

fit メソッドの x にそのまま generator を渡すだけである。generator を渡す場合は x だけ指定すればよく、y は指定してはいけない。

model.fit(
    x=MySequence(),
    epochs=5,
    verbose=1
)