Добавил:
Опубликованный материал нарушает ваши авторские права? Сообщите нам.
Вуз: Предмет: Файл:
Шолле Ф. - Глубокое обучение на Python (Библиотека программиста) - 2023.pdf
Скачиваний:
4
Добавлен:
07.04.2024
Размер:
11.34 Mб
Скачать
Эти два аргумента требуют, чтобы файл модели не перезаписывался, если значение val_loss не улучшилось, что позволяет сохранить только лучшую модель

7.3. Встроенные циклы обучения и оценки    247

переобучения.и.тем.самым.избежать.повторного.обучения.модели.для.меньшего. количества.эпох..Данный.обратный.вызов.обычно.используется.в.комбинации. с.обратным.вызовом.ModelCheckpoint,.который.может.сохранять.состояние.модели.в.ходе.обучения.(и.при.необходимости.сохранять.только.лучшую.модель:. версию,.достигшую.лучшего.качества.к.концу.эпохи):

Листинг 7.19. Использование параметра callbacks метода fit()

Обратные вызовы передаются в модель

Прерывает обучение,

через параметр callbacks метода fit()

когда качество

в виде списка. Вы можете передать любое

модели перестает

количество обратных вызовов

улучшаться

 

 

callbacks_list = [

 

 

 

 

 

Сохраняет

 

 

keras.callbacks.EarlyStopping(

 

 

 

 

 

 

monitor="val_accuracy",

 

текущие веса

 

 

patience=2,

после каждой

 

),

 

 

 

 

 

эпохи

 

 

keras.callbacks.ModelCheckpoint(

 

 

 

Путь

 

 

filepath="checkpoint_path.keras",

 

 

 

 

monitor="val_loss",

 

 

 

к файлу

 

 

 

 

 

 

save_best_only=True,

 

 

 

модели

 

 

 

 

 

 

 

)

 

 

 

 

 

Следит за изменением точности модели на проверочных данных

Прерывает обучение, если точность не улучшается

в течение двух эпох

]

model = get_mnist_model() model.compile(optimizer="rmsprop",

loss="sparse_categorical_crossentropy", metrics=["accuracy"])

 

model.fit(train_images, train_labels,

Мы следим за точностью,

 

 

epochs=10,

поэтому она должна быть частью

 

callbacks=callbacks_list,

набора метрик модели

 

 

validation_data=(val_images, val_labels))

Обратите внимание: поскольку обратный вызов следит

 

за потерями и точностью на проверочных данных,

 

мы должны передать validation_data в вызов fit()

 

Помните,.что.модель.всегда.можно.сохранить.вручную.после.обучения:.нужно. лишь.вызвать.метод.model.save('путь_к_файлу')..Чтобы.загрузить.сохраненную. модель,.просто.примените:

model = keras.models.load_model("checkpoint_path.keras")

7.3.3. Разработка своего обратного вызова

Если.в.ходе.обучения.потребуется.выполнить.какие-то.особые.действия,.не.преду­ смотренные.ни.одним.из.встроенных.обратных.вызовов,.можно.написать.свой. обратный.вызов..Обратные.вызовы.реализуются.путем.создания.подкласса. класса.keras.callbacks.Callback..Вы.можете.реализовать.любые.из.следующих.

248    Глава 7. Работа с Keras: глубокое погружение

методов.с.говорящими.именами,.которые.будут.вызываться.в.соответствующие. моменты.в.ходе.обучения:

Вызывается в начале каждой эпохи

on_epoch_begin

on_epoch_end

 

Вызывается в конце каждой эпохи

 

on_batch_begin

 

 

Вызывается перед началом обработки каждого пакета

 

 

on_batch_end

 

Вызывается сразу после окончания обработки каждого пакета

 

on_train_begin

 

 

Вызывается в начале обучения

 

 

on_train_end

 

Вызывается в конце обучения

 

 

 

Все.эти.методы.вызываются.с.аргументом.logs .—.словарем,.содержащим.ин- формацию.о.предыдущем.пакете,.эпохе.или.цикле.обучения.(метрики.обучения. и.проверки.и.т..д.)..Методам.on_epoch_* .и.on_batch_* .также.передается.индекс. эпохи.или.пакета.в.первом.аргументе.(целое.число).

Вот.простой.пример.обратного.вызова,.который.сохраняет.список.значений.потерь.для.каждого.пакета.во.время.обучения.и.график.изменения.потерь.в.конце. каждой.эпохи.

Листинг 7.20. Создание своего обратного вызова наследованием класса Callback

from matplotlib import pyplot as plt

class LossHistory(keras.callbacks.Callback): def on_train_begin(self, logs):

self.per_batch_losses = []

def on_batch_end(self, batch, logs): self.per_batch_losses.append(logs.get("loss"))

def on_epoch_end(self, epoch, logs): plt.clf()

plt.plot(range(len(self.per_batch_losses)), self.per_batch_losses, label="Потери на обучающих данных для каждого пакета")

plt.xlabel(f"Пакеты (эпоха {epoch})") plt.ylabel("Потери")

plt.legend() plt.savefig(f"plot_at_epoch_{epoch}") self.per_batch_losses = []

Испытаем.его:

model = get_mnist_model() model.compile(optimizer="rmsprop",

loss="sparse_categorical_crossentropy", metrics=["accuracy"])

model.fit(train_images, train_labels, epochs=10, callbacks=[LossHistory()],

validation_data=(val_images, val_labels))

7.3. Встроенные циклы обучения и оценки    249

Сохраненный.график.можно.увидеть.на.рис..7.5.

Рис. 7.5. График, созданный нашим собственным обратным вызовом

7.3.4. Мониторинг и визуализация с помощью TensorBoard

Для.проведения.результативных.исследований.или.разработки.качественных. моделей.необходимо.иметь.разностороннюю,.часто.обновляющуюся.информацию.о.происходящем.внутри.модели.в.ходе.экспериментов..В.этом.суть.экс- периментов.—.получить.информацию.(как.можно.больше.информации).о.том,. насколько.хорошо.работает.модель..Движение.вперед.носит.итеративный,.или. циклический,.характер..Вы.начинаете.с.идеи.и.разрабатываете.план.эксперимента,.который.подтвердит.или.опровергнет.ее..Далее.вы.запускаете.эксперимент.

и.обрабатываете.полученную.информацию..Это.дает.толчок.к.рождению.новой. идеи..И.чем.больше.итераций.в.данном.цикле.вы.выполните,.тем.совершеннее.

и.мощнее.будут.становиться.ваши.идеи..Keras.поможет.вам.перейти.от.идеи. к.эксперименту.в.кратчайшие.сроки,.а.с.помощью.GPU.вы.получите.результаты. эксперимента.достаточно.быстро..Но.как.быть.с.обработкой.результатов?.Здесь. вам.пригодится.TensorBoard.(рис..7.6).

TensorBoard.(www.tensorflow.org/tensorboard).—.браузерное.приложение,.которое. можно.запускать.локально..Это.лучший.способ.наблюдения.за.происходящим. внутри.модели.во.время.обучения..TensorBoard.позволяет:

. визуально.контролировать.метрики.в.процессе.обучения;

. отображать.архитектуру.модели;

. выводить.гистограммы.активаций.и.градиентов;

. исследовать.векторные.представления.в.трехмерной.системе.координат.

250    Глава 7. Работа с Keras: глубокое погружение

Рис. 7.6. Циклическое движение вперед

Самый.простой.способ.использовать.TensorBoard.с.моделью.Keras.и.методом. fit() .—.определить.обратный.вызов.keras.callbacks.TensorBoard.

В.простейшем.случае.достаточно.указать,.куда.должна.записываться.информация. этим.обратным.вызовом,.и.все:

model = get_mnist_model() model.compile(optimizer="rmsprop",

loss="sparse_categorical_crossentropy", metrics=["accuracy"])

tensorboard = keras.callbacks.TensorBoard( log_dir="/full_path_to_your_log_dir",

)

model.fit(train_images, train_labels, epochs=10,

validation_data=(val_images, val_labels), callbacks=[tensorboard])

С.началом.обучения.модель.будет.записывать.информацию.в.указанное.местоположение..Если.обучение.выполняется.на.локальном.компьютере,.то.вы. можете.запустить.локальный.сервер.TensorBoard.следующей.командой.(обратите.внимание,.что.выполняемый.файл.tenorboard .уже.должен.быть.доступен,. если.библиотека.TensorFlow.устанавливалась.с.помощью.pip;.если.нет,.можно. установить.TensorBoard.вручную.командой.pip install tensorboard):

tensorboard --logdir /full_path_to_your_log_dir

Данная.команда.выведет.URL,.который.затем.можно.ввести.в.адресную.строку. браузера,.чтобы.получить.доступ.к.интерфейсу.TensorBoard.

Если.обучение.производится.в.блокноте.Colab,.то.можно.запустить.встроенный. экземпляр.TensorBoard.в.блокноте,.выполнив.следующую.команду:

%load_ext tensorboard

%tensorboard --logdir /full_path_to_your_log_dir

В.интерфейсе.TensorBoard.можно.наблюдать.в.режиме.реального.времени,.как. протекает.процесс.обучения.модели.(рис.7.7).