Используйте алгоритм GraphSAGE в сочетании с встраиванием слов OpenAI для повышения точности классификации последующих документов.

В последние годы обработка естественного языка (NLP) быстро развивалась. Одним из важных аспектов этого прогресса было использование вложений, которые представляют собой числовые представления слов или фраз, которые фиксируют их значение и отношения с другими словами в языке. Вложения можно использовать в широком спектре задач НЛП, таких как классификация документов, машинный перевод, анализ тональности и распознавание именованных сущностей. Кроме того, с появлением больших предварительно обученных языковых моделей, таких как GPT-3, встраивания стали еще более важными для обеспечения возможности переноса обучения для ряда языковых задач и областей. Таким образом, встраивания тесно связаны с быстрым развитием НЛП.

С другой стороны, недавние достижения в области графов и графовых нейронных сетей привели к повышению производительности при решении широкого круга задач, включая распознавание изображений, поиск лекарств и рекомендательные системы. В частности, графовые нейронные сети продемонстрировали большие перспективы в изучении представлений графоструктурированных данных, где отношения между точками данных обеспечивают сигнал, повышающий точность последующих задач машинного обучения.

В этом сообщении блога вы узнаете, как использовать возможности графических нейронных сетей для захвата и кодирования взаимосвязей между точками данных и повышения точности классификации документов. В частности, вы будете обучать две модели прогнозированию тегов средней статьи.

Большинство статей среднего размера имеют соответствующие теги, назначенные им автором для облегчения обнаружения и эффективности поиска. Кроме того, вы можете рассматривать эти теги как категоризацию статей. Каждая статья может иметь до 5 тегов или категорий, к которым она принадлежит, как показано на изображении выше. Поэтому вы научите две модели классификации выполнять классификацию по нескольким меткам, где каждой статье может быть присвоен один или несколько тегов.

Первая модель классификации будет использовать последние вложения OpenAI (text-embedding-ada-002) заголовка и подзаголовка статьи в качестве входных признаков. Эта модель обеспечит базовую точность, которую вы попытаетесь улучшить с помощью алгоритма графовой нейронной сети под названием GraphSAGE. Интересно, что вложения слов могут быть и будут использоваться в этом примере в качестве входных данных для GraphSAGE. Затем во время обучения алгоритм GraphSAGE использует эти вложения слов для итеративной агрегации информации из соседних узлов, что приводит к мощным представлениям на уровне узлов, которые могут повысить точность последующих задач машинного обучения, таких как классификация документов.

Короче говоря, в этом сообщении блога исследуется использование графовых нейронных сетей для улучшения встраивания слов с учетом взаимосвязей между точками данных. Когда отношения между точками данных актуальны и предсказуемы, графовые нейронные сети могут изучать более содержательные и точные представления текстовых данных и, следовательно, повышать точность последующих моделей машинного обучения.

Средний набор данных

На Kaggle доступно несколько наборов данных статей среднего размера. Однако ни один из них не содержит отношений между статьями. Какой тип взаимосвязей между статьями может быть предсказательным для предсказания их тегов? Medium добавил возможность для пользователей создавать списки, которые могут помочь им добавлять в закладки и выбирать контент, который у них есть или который они собираются прочитать.

На этом изображении представлен пример, когда пользователь создал четыре списка статей на основе их тем. Например, большинство статей было сгруппировано в списке Наука о данных, тогда как другие статьи были добавлены в список Общение, Математика и Дизайн. списки. Идея состоит в том, что если две статьи находятся в одном списке, они несколько более похожи, чем если бы у них не было общих списков. Средние списки можно рассматривать как аннотированные человеком отношения между статьями, которые могут помочь вам найти и потенциально порекомендовать похожие статьи.

Есть одно исключение из этого предположения. Некоторые пользователи создают обширные списки для чтения, содержащие всевозможные статьи.

Интересно, что большинство этих списков с большим количеством статей имеют одинаковое название Список для чтения. Так что это должно быть какое-то значение по умолчанию для Medium или что-то в этом роде, поскольку я заметил заголовок списка для чтения у нескольких пользователей.

К сожалению, общедоступных наборов данных с информацией о статьях Medium, а также списках пользователей, которым они принадлежат, нет. Поэтому мне пришлось потратить полдня на разбор данных. Я получил информацию о 55 тысячах медийных статей из 4000 списков пользователей.

Подготовка среды Neo4j

Построение графа и обучение GraphSAGE будет выполняться в Neo4j. Мне нравится Neo4j, так как он предлагает хорошо разработанный язык запросов к графам под названием Cypher, а также плагин Graph Data Science, который содержит более 50 графических алгоритмов, которые охватывают большую часть рабочего процесса графовой аналитики. Поэтому нет необходимости использовать несколько инструментов для создания и анализа графика.

Схема графика набора данных Medium, если:

Схема вращается вокруг статей Medium. Нам известны url, название и дата статьи. Кроме того, я рассчитал встраивания OpenAI, используя модель text-embedding-ada-002 на основе заголовка статьи и подзаголовка, и сохранил их как свойство openaiEmbedding. Кроме того, мы знаем, кто написал статью, к каким спискам пользователей она принадлежит и какие у нее теги.

Я подготовил для вас два варианта импорта среднего набора данных в базу данных Neo4j. Вы можете выполнить следующую записную книжку Jupyter и импортировать набор данных из Python. Этот параметр также работает с средой песочницы Neo4j (используйте проект по науке о данных с пустым графиком).



Другой вариант — восстановить подготовленный мной дамп базы данных Neo4j.



Дамп был создан с помощью Neo4j версии 5.5.0, поэтому обязательно используйте эту версию или более позднюю. Самый простой способ восстановить дамп базы данных — воспользоваться средой рабочего стола Neo4j. Кроме того, вам потребуется установить библиотеки APOC и GDS, если вы используете среду рабочего стола Neo4j.

После завершения импорта базы данных вы можете запустить следующую инструкцию Cypher в браузере Neo4j, чтобы убедиться, что импорт прошел успешно.

MATCH p=(n:Author)-[:WROTE]->(d)-[:IN_LIST]->(), p1=(d)-[:HAS_TAG]->()
WHERE n.name = "Tomaz Bratanic"
RETURN p,p1 LIMIT 25

Результат будет содержать пару статей, которые я написал вместе с их списками и тегами.

Теперь пришло время для практической части этой записи в блоге. Весь код анализа доступен в виде Jupyter Notebook.



Исследовательский анализ

Мы будем использовать Graph Data Science Python Client для взаимодействия с Neo4j и его плагином Graph Data Science. Это отличное дополнение к экосистеме Neo4j, позволяющее нам выполнять графовые алгоритмы с использованием чистого кода Python. Прочтите мой вводный пост в блоге для получения дополнительной информации.

Во-первых, мы оценим распределение тегов по статьям-носителям.

dist_df = gds.run_cypher("""
MATCH (a:Article)
RETURN count{(a)-[:HAS_TAG]->()} AS count
""")

sns.displot(dist_df['count'], height=6, aspect=1.5)

Около 50% статей не имеют тегов. На это есть две причины. Либо автор их не использовал, либо процесс удаления не смог их получить по разным причинам, например, в публикациях на носителях, имеющих пользовательские структуры HTML. Однако это не имеет большого значения, поскольку у нас все еще есть более 25 тысяч статей с их тегами, что позволяет нам обучать и оценивать модель классификации тегов статей с несколькими метками. Большинство авторов предпочитают использовать пять тегов в статье, что также является верхним пределом, который позволяет платформа Medium.

Далее мы оценим, не входят ли какие-либо статьи в какие-либо списки пользователей.

gds.run_cypher(
    """
MATCH (a:Article)
RETURN exists {(a)-[:IN_LIST]-()} AS in_list,
       count(*) AS count
ORDER BY count DESC
"""
)

Результаты показывают, что все статьи принадлежат хотя бы одному списку. Идентификация изолированных узлов (узлов без соединения) является важной частью любого рабочего процесса графовой аналитики, поскольку мы должны уделять им особое внимание при расчете встраивания узлов. К счастью, этот набор данных не содержит изолированных узлов, так что нам не о чем беспокоиться.

В последней части исследовательского анализа мы рассмотрим наиболее часто встречающиеся теги. Здесь мы построим облако слов из тегов, присутствующих не менее чем в 100 статьях.

tags = gds.run_cypher(
    """
MATCH (t:Tag)
WITH t, count {(t)<--()} AS size
WHERE size > 100
RETURN t.name AS tag, size
ORDER BY size DESC
"""
)

d = {}
for i, row in tags.iterrows():
    d[row["tag"]] = row["size"]

wordcloud = WordCloud(
    background_color="white", colormap="tab20c", min_font_size=1
).generate_from_frequencies(d)
plt.figure()
plt.imshow(wordcloud)
plt.axis("off")
plt.show()

Наиболее частыми тегами являются наука о данных, искусственный интеллект, программирование и машинное обучение.

Многоуровневая классификация

Как уже упоминалось, мы будем обучать модель классификации с несколькими метками для прогнозирования тегов статьи на Medium. Поэтому мы будем использовать библиотеку scikit-multilearn, которая поможет с разделением данных и обучением модели.

Я заметил, что разделение набора данных с помощью библиотеки scikit-multilearn не предоставляет случайный начальный параметр, и поэтому разделение набора данных не является детерминированным. Для правильного сравнения базовой модели, обученной на встраивании слов OpenAI, и модели, основанной на встраиваниях GraphSAGE, мы выполним разделение одного набора данных, чтобы обе версии модели использовали одни и те же обучающие и тестовые примеры. В противном случае могут быть некоторые различия в точности моделей, основанные исключительно на разделении набора данных.

Вложения слов уже сохранены в графе, поэтому нам нужно только вычислить вложения узлов с помощью алгоритма GraphSAGE, прежде чем мы сможем обучить модели классификации.

ГрафикSAGE

GraphSAGE — это алгоритм нейронной сети сверточных графов. Ключевая идея алгоритма заключается в том, что мы изучаем функцию, которая генерирует вложения узлов, путем выборки и агрегирования информации об объектах из локального окружения узла. Поскольку алгоритм GraphSAGE изучает функцию, которая может вызвать встраивание узла, его также можно использовать для создания вложений нового узла, который не наблюдался на этапе обучения. Это называется индуктивным обучением.

Исследование окрестностей и обмен информацией в GraphSAGE. [1]

Если вы хотите узнать больше о процессе обучения и математике, лежащей в основе алгоритма GraphSAGE, я предлагаю вам взглянуть на запись в блоге Интуитивное объяснение GraphSAGE Ризы Озчелик или на официальный сайт GraphSAGE.

Однодольная проекция с алгоритмом сходства узлов

GraphSAGE поддерживает графы с несколькими типами узлов, где каждый тип узла имеет различные функции, представляющие его. В нашем примере у нас есть узлы Статья и Список. Однако я решил упростить рабочий процесс, выполнив одночастную проекцию.

Однодольная проекция — частый этап анализа графов. Идея состоит в том, чтобы взять двудольный граф (граф с двумя типами узлов) и вывести однодольный граф (граф только с одним типом узлов). В этом конкретном примере мы можем создать связь между двумя статьями, если они являются частью одного и того же списка. Кроме того, количество общих списков или нормализованное значение, такое как коэффициент Жаккара, можно сохранить как свойство отношения.

Поскольку однодольная проекция является обычным шагом в анализе графов, библиотека Neo4j Graph Data Science предлагает алгоритм сходства узлов, который поможет нам в этом.

Во-первых, нам нужно спроецировать график в памяти. Мы включим узлы Статья и Список вместе с отношениями IN_LIST. Кроме того, мы включим свойства узла openaiEmbedding.

G, metadata = gds.graph.project(
    "articles", 
    ["Article", "List"],
    "IN_LIST", 
    nodeProperties=["openaiEmbedding"]
)

Теперь мы можем выполнить однодольную проекцию, используя алгоритм сходства узлов. Следует отметить, что значение параметра topK по умолчанию равно 10, что означает, что каждый узел будет подключен только к своим десяти наиболее похожим узлам. Однако в этом примере мы хотим создать связь между всеми статьями в списке пользователей. Поэтому мы будем использовать относительно высокое значение параметра topK.

gds.nodeSimilarity.mutate(
    G, topK=2000, mutateProperty="score", mutateRelationshipType="SIMILAR"
)

Мы использовали режим алгоритма mutate, который сохраняет результаты обратно в спроецированный граф в памяти. Отношение SIMILAR было создано между всеми парами статей, которые имеют хотя бы один общий список пользователей.

Обучение модели GraphSAGE

Алгоритм GraphSAGE является индуктивным, что означает, что его можно использовать для создания вложений для узлов, которые ранее не были видны во время обучения. Индуктивный характер позволяет нам обучать модель GraphSAGE только на подмножестве графа, а затем генерировать вложения для всех узлов. Обучение модели GraphSAGE только на подмножестве графа экономит время и вычислительную мощность, что полезно при работе с большими графами. Хотя наш график не такой большой, мы можем использовать этот пример, чтобы продемонстрировать, как эффективно выбирать обучающее подмножество графика.

Случайный блуждание с выборкой перезапусков

Идея случайного блуждания с выборкой перезапусков довольно проста. Алгоритм совершает случайные обходы из набора предопределенных начальных узлов. На каждом шаге блуждания существует вероятность того, что текущее случайное блуждание остановится и начнется новое из множества начальных узлов. Пользователь может определить начальные узлы. Если начальные узлы не определены, алгоритм выбирает их случайным образом.

Я подумал, что было бы интересно показать вам пример выбора начального узла вручную. Итак, мы начнем с выполнения алгоритма Weakly Connected Components, чтобы оценить, насколько связен граф статей. Слабосвязный компонент — это набор узлов в графе, где существует путь между всеми узлами в наборе, если направление отношений игнорируется.
Слабосвязный компонент можно рассматривать как остров, до которого не могут добраться узлы из других компонентов. .
Хотя алгоритм идентифицирует связанные наборы узлов, его выходные данные могут помочь вам оценить, насколько несвязанным является общий граф.

wcc = gds.wcc.stream(G)
wcc_grouped = (
    wcc.groupby("componentId")
    .size()
    .to_frame("componentSize")
    .reset_index()
    .sort_values("componentSize", ascending=False)
    .reset_index()
)
print(wcc_grouped)

Всего в нашем графе 604 компонента связности. Самый большой компонент содержит 98% всех узлов, в то время как другие меньше, причем многие из них содержат только два узла. Если компонент содержит только два узла, это означает, что у нас есть средний список пользователей, в котором есть только две статьи, и эти две статьи не являются частью каких-либо других списков.

Мы выполнили алгоритм Weakly Connected Component, чтобы определить узел, который принадлежит большому компоненту связности и, следовательно, может использоваться в качестве начального узла алгоритма выборки. Например, если бы мы использовали узел только с одним соседом, алгоритм выборки не мог бы выполнять более длительные обходы для эффективной подвыборки графа.

К счастью, реализован алгоритм выборки, который автоматически расширяет набор начальных узлов, если случайные блуждания не посещают новые узлы. Однако, поскольку мы использовали начальный узел из самого большого связанного компонента с 98% всех узлов, алгоритму не нужно будет автоматически расширять набор начальных узлов.

largest_component = wcc_grouped["componentId"][0]
start_node = wcc[wcc["componentId"] == largest_component]["nodeId"][0]

trainG, metadata = gds.alpha.graph.sample.rwr(
    "trainGraph",
    G,
    samplingRatio=0.20,
    startNodes=[int(start_node)],
    nodeLabels=["Article"],
    relationshipTypes=["SIMILAR"],
)

Параметр коэффициента выборки определяет долю узлов исходного графа для выборки. Например, при использовании значения 0,20 для коэффициента выборки размер выборочного подграфа будет составлять 20 % от размера исходного графа. Кроме того, нам нужно определить, что случайные блуждания могут посещать узлы Article только через отношения SIMILAR с использованием nodeLabels и relationshipTypes. сильные> параметры.

Обучение GraphSAGE

Наконец, мы можем продолжить обучение модели GraphSAGE на выбранном подграфе.

gds.beta.graphSage.train(
    trainG,
    modelName="articleModel",
    embeddingDimension=256,
    sampleSizes=[10, 10],
    searchDepth=15,
    epochs=20,
    learningRate=0.0001,
    activationFunction="RELU",
    aggregator="MEAN",
    featureProperties=["openaiEmbedding"],
    batchSize=10,
)

Алгоритм GraphSAGE будет использовать свойство узла openaiEmbedding в качестве входных признаков. Вложения GraphSAGE будут иметь размерность 256 (размер вектора). Пока я экспериментировал с оптимизацией гиперпараметров для этого блога, я заметил, что скорость обучения и функция активации являются наиболее важными параметрами.

Создание вложений

После обучения модели GraphSAGE мы можем использовать ее для расчета вложений узлов для всех узлов Article в исходном более крупном спроецированном графе и рассматривать только отношения SIMILAR.

gds.beta.graphSage.write(
    G,
    modelName="articleModel",
    nodeLabels=["Article"],
    writeProperty="graphSAGE",
    relationshipTypes=["SIMILAR"],
)

На этот раз мы использовали режим записи для сохранения вложений GraphSAGE в качестве свойств узла в базе данных.

Модель классификации

Мы подготовили вложения как OpenAI, так и GraphSAGE. Осталось только обучить модели и сравнить их производительность.

Во-первых, мы пометим теги статей, которые мы хотим предсказать. Я произвольно решил включить только те теги, которые присутствуют как минимум в 100 статьях. Целевые теги будут помечены вторичной меткой Target.

gds.run_cypher(
    """
MATCH (t:Tag)
WHERE count{(t)<--()} > 100
SET t:Target
RETURN count(*) AS count
"""
)

Мы пометили 161 тег, который хотим предсказать. Помните, что приведенная выше визуализация облака слов использовала те же 161 тег и визуализировала их в соответствии с их частотой.

Поскольку мы будем использовать библиотеку scikit-multilearn, нам необходимо экспортировать соответствующую информацию из Neo4j.

data = gds.run_cypher(
    """
MATCH (a:Article)-[:HAS_TAG]->(tag:Target)
RETURN a.url AS article,
        a.openaiEmbedding AS openai,
        a.graphSAGE AS graphSAGE,
        collect(tag.name) AS tags
"""
)

Далее нам нужно построить бинарную матрицу, указывающую на наличие тегов для данной статьи. По сути, вы можете думать об этом как о горячем кодировании тегов для каждой статьи. Итак, для этого мы можем использовать процедуру MultiLabelBinarizer.

mlb = MultiLabelBinarizer()
tags_mlb = mlb.fit_transform(data["tags"])
data["target"] = list(tags_mlb)

Библиотека scikit-multilearn предлагает улучшенное разделение набора данных для задач прогнозирования с несколькими метками. Однако он не допускает детерминированного подхода со случайным начальным параметром. Поэтому мы выполним разделение набора данных только один раз для встраивания слов и GraphSAGE, а затем соответствующим образом обучим две модели.

Следующая функция принимает фрейм данных и столбцы, которые должны использоваться отдельно в качестве входных признаков для модели классификации с несколькими метками, и возвращает наиболее эффективную модель при печати взвешенного макроса и взвешенной точности. Здесь мы используем подход LabelPowerset к классификации с несколькими метками.

def train_and_evaluate(df, input_columns):
    max_weighted_precision = 0
    best_input = ""
    # Single split data
    X = data[input_columns].values
    y = np.array(data["target"].to_list())
    x_train_all, y_train, x_test_all, y_test = iterative_train_test_split(
        X, y, test_size=0.2
    )
    # Train a model for each input option
    for i, input_column in enumerate(input_columns):
        print(f"Training a model based on {input_column} column")
        x_train = np.array([x[i] for x in x_train_all])
        x_test = np.array([x[i] for x in x_test_all])

        # train
        classifier = LabelPowerset(LogisticRegression())
        classifier.fit(x_train, y_train)
        # predict
        predictions = classifier.predict(x_test)
        print("Test accuracy is {}".format(accuracy_score(y_test, predictions)))
        print(
            "Macro Precision: {:.2f}".format(
                get_macro_precision(mlb.classes_, y_test, predictions)
            )
        )
        weighted_precision = get_weighted_precision(mlb.classes_, y_test, predictions)
        print("Weighted Precision: {:.2f}".format(weighted_precision))
        if weighted_precision > max_weighted_precision:
            max_weighted_precision = weighted_precision
            best_classifier = classifier
            best_input = input_column

    return best_classifier, best_input

Когда все подготовлено, мы можем приступить к обучению моделей на основе вложений слов и графов SAGE и сравнить их производительность.

пс. Если вы используете Google Colab, вы можете столкнуться с проблемами OOM при использовании встраивания openai

classifier, best_input = train_and_evaluate(data, ["openai", "graphSAGE"])

Результаты следующие:

Training a model based on openai column
Test accuracy is 0.055443548387096774
Macro Precision: 0.20
Weighted Precision: 0.36
Training a model based on graphSAGE column
Test accuracy is 0.05584677419354839
Macro Precision: 0.30
Weighted Precision: 0.41

Хотя встраивание заголовка и подзаголовка предоставляет некоторую информацию о своих тегах, оно может быть не самым эффективным. Это может быть связано с заголовками в стиле кликбейтов, в которых приоритет отдается привлечению внимания, а не точному описанию контента. Кроме того, у авторов могут быть разные предпочтения в отношении маркировки одинакового контента разными метками. Несмотря на эти проблемы, наша модель предсказывает 161 метку, многие из которых имеют несколько примеров, что дает приемлемые результаты. Для дальнейшего повышения точности мы можем встроить весь текст статьи и оценить его эффективность.

Интересно, что использование вложений GraphSAGE повышает точность классификации за счет учета взаимосвязей между статьями. Макро-точность нашей модели улучшается на десять процентных пунктов, а взвешенная точность улучшается на пять. Эти результаты демонстрируют, что встраивания GraphSAGE помогают более эффективно идентифицировать нечастые теги. В отличие от стандартных моделей встраивания слов, графовые нейронные сети позволяют нам кодировать дополнительные отношения между точками данных, тем самым улучшая последующие модели машинного обучения. Мы также уменьшили размерность с 1536 до 256, увеличив при этом производительность, что является отличным результатом.

Тестовые предсказания

В нашей базе почти 50% статей без тегов. Мы можем протестировать модель на нескольких и вручную оценить результаты.

example = gds.run_cypher(
    """
MATCH (a:Article)
WHERE NOT EXISTS {(a)-[:HAS_TAG]->()}
RETURN a.title AS title,
       a.openaiEmbedding AS openai,
       a.graphSAGE AS graphSAGE
LIMIT 15
"""
)

tags_predicted = classifier.predict(np.array(example[best_input].to_list()))
example["tags"] = [list(mlb.inverse_transform(x)[0]) for x in tags_predicted]
example[["title", "tags"]]

Результаты

Интересно, что модель в основном присваивает каждой статье один или два ярлыка, в то время как большинство реальных статей имеют пять тегов. Вероятно, это одна из причин значений показателей точности. Кроме того, результаты выглядят многообещающе, судя по этой небольшой выборке.

Краткое содержание

Традиционные модели встраивания слов, такие как word2vec, фокусируются на кодировании статистики совпадения слов. Однако они полностью игнорируют любые другие отношения, которые можно найти между точками данных. Например, у нас были пользователи, которые комментировали похожие статьи, помещая их в различные списки для чтения. К счастью, графовые нейронные сети предлагают мост между традиционными встраиваниями слов и вложениями графов, поскольку они позволяют нам строить поверх встраивания слов и кодировать дополнительную информацию, полученную из отношений между точками данных. Таким образом, графовые нейронные сети не обязательно начинать с нуля, их можно использовать для улучшения современных вложений слов или документов.

Рекомендации

[1] Гамильтон, Уилл, Читао Ин и Юре Лесковец. «Индуктивное репрезентативное обучение на больших графах. Достижения в области нейронных систем обработки информации. 2017.»