Машинное обучение
Как добавить новые данные в предварительно обученную модель в Scikit-learn
Пошаговое руководство по использованию warm_start=True и partial_fit() в scikit-learn
Когда вы создаете модель машинного обучения с нуля, обычно вы разделяете свой набор данных на обучающий и тестовый наборы, а затем обучаете свою модель на своем обучающем наборе. Затем вы проверяете производительность своей модели на своем тестовом наборе, и если вы получаете что-то приличное, вы можете использовать свою модель для прогнозирования.
Но что, если в какой-то момент станут доступны новые данные?
Другими словами, как обучить уже обученную модель? Или опять же, как добавить новые данные в уже обученную модель?
В этой статье я попытаюсь дать некоторые ответы на этот нетривиальный вопрос, используя библиотеку scikit-learn. Вы можете прочитать эту интересную статью Vidhi Chugh, чтобы понять, когда вам нужно переобучить вашу модель.
Одним из возможных (тривиальных) решений предыдущего вопроса может быть обучение модели с нуля с использованием как старых, так и новых данных. Однако это решение не масштабируется, если первое обучение требует длительного времени.
Решение проблемы — добавить образцы в уже обученную модель. И этот scikit-learn позволяет вам это сделать в некоторых случаях. Просто соблюдайте некоторые меры предосторожности.
Scikit-learn предлагает две стратегии:
- частичная подгонка
- теплый старт
Чтобы проиллюстрировать, как добавлять новые данные в предварительно обученную модель в Scikit-learn, я буду использовать практический пример, используя хорошо известный набор данных iris, предоставляемый библиотекой Scikit-learn.
теплый старт
Теплый старт — это параметр, предоставляемый некоторыми моделями Scikit. Если для него установлено значение True, он позволяет использовать существующие атрибуты подогнанной модели для инициализации новой модели в последующем вызове подгонки.
Например, вы можете установить warm_start = True
в классификаторе случайного леса, тогда вы сможете регулярно подгонять модель. Если вы снова вызовете метод подгонки для новых данных, к существующим деревьям будут добавлены новые оценки. Это означает, что использование warm_start = True не меняет существующие деревья.
warm_start = True
следует не использовать для постепенного изучения новых наборов данных, где может быть дрейф концепций. Дрейф концепции — это тип дрейфа в модели данных, который происходит, когда изменяется базовая связь между выходными и входными переменными.
Чтобы понять, как работает warm_start = True
, я описываю пример. Идея состоит в том, чтобы показать, что использование warm_start = True может улучшить производительность алгоритма, если я добавлю новые данные, которые имеют то же распределение, что и исходные данные, и которые поддерживают те же отношения с выходной переменной.
Во-первых, я загружаю набор данных iris, предоставленный библиотекой Scikit-learn:
from sklearn import datasets iris = datasets.load_iris() X = iris.data y = iris.target
Затем я разделил набор данных на три части:
X_train
,y_train
— обучающая выборка 80% из 40% данных (48 выборок)X_test
,y_test
— тестовый набор 20% из 40 данных (12 выборок)X2, y2
— новые образцы (60% данных) (90 образцов)
from sklearn.model_selection import train_test_split X1, X2, y1, y2 = train_test_split(X, y, test_size=0.60, random_state=42) X_train, X_test, y_train, y_test = train_test_split(X1, y1, test_size=0.20, random_state=42)
Я буду использовать X2
и y2
для переобучения модели.
Обратите внимание, что обучающая выборка очень мала (48 образцов).
Я тренирую модель с warm_start = False
:
from sklearn.ensemble import RandomForestClassifier model = RandomForestClassifier(max_depth=2, random_state=0, warm_start=False, n_estimators=1) model.fit(X_train, y_train)
Подсчитываю балл:
model.score(X_test, y_test)
который дает следующий результат:
0.75
Теперь я подогнал модель к новым данным:
model.fit(X2, y2)
Предыдущая подгонка удаляет уже изученную модель. Затем я подсчитываю баллы:
model.score(X_test, y_test)
который дает следующий результат:
0.8333333333333334
Теперь я создаю новую модель с параметром warm_start = True, чтобы посмотреть, увеличится ли оценка модели.
model = RandomForestClassifier(max_depth=2, random_state=0, warm_start=True, n_estimators=1) model.fit(X_train, y_train) model.score(X_test, y_test)
который дает следующий результат:
0.75
Теперь я подгоняю модель и вычисляю оценку:
model.n_estimators+=1 model.fit(X2, y2) model.score(X_test, y_test)
который дает следующий результат:
0.9166666666666666
Инкрементальное обучение улучшило оценку!
частичная подгонка
Вторая стратегия, предоставляемая Scikit-learn для добавления новых данных в предварительно обученную модель, — это использование метода partial_fit()
. Не все модели поддерживают этот метод.
Хотя параметр warm_start = True
не изменяет параметры атрибута, уже изученные моделью, частичное соответствие может изменить его, поскольку оно учится на новых данных.
Я снова рассматриваю набор данных радужной оболочки.
Теперь я использую SGDClassifier
:
from sklearn.linear_model import SGDClassifier import numpy as np model = SGDClassifier() model.partial_fit(X_train, y_train, classes=np.unique(y))
При первом запуске метода partial_fit() я должен передать этому методу также все классы. В этом. Например, я предполагаю, что знаю все классы, содержащиеся в y, хотя у меня недостаточно образцов для их представления.
Подсчитываю балл:
model.score(X_test, y_test)
который дает следующий результат:
0.4166666666666667
Теперь я добавляю в модель новые образцы:
model.partial_fit(X2, y2)
и я вычисляю счет:
model.score(X_test, y_test)
который дает следующий результат:
0.8333333333333334
Добавление новых данных улучшило производительность алгоритма!
Краткое содержание
Поздравляем! Вы только что узнали, как добавлять новые данные в предварительно обученную модель в Scikit-learn! Вы можете использовать либо параметр warm_start
, установленный на True
, либо метод partial_fit()
. Однако не все модели в библиотеке Scikit-learn позволяют добавлять новые данные в предварительно обученную модель. Поэтому я предлагаю проверить документацию!
Вы можете скачать код, использованный в этом руководстве, из моего репозитория Github.
Если вы дочитали до этого места, то для меня это уже много на сегодня. Спасибо! Вы можете прочитать мои популярные статьи по этой ссылке.