В литературе по рекомендательным системам есть хорошо известный метод исправления этой предвзятости. Этот метод коррекции смещения выборки, также известный как коррекция log-Q, описан в этой статье Google. Основная идея состоит в том, чтобы получить хорошее приближение к знаменателю soft-max (также известному как функция разделения), используя документы в мини-пакете:

q = query_embedding
d_k = k-th document embedding
d_0 = correct document embedding

partition_function
= sum over the document corpus k(exp(<q, d_k> / T)
= exp(<q, d_0> / T) + sum over the negative document corpus(exp(<q, d_k> / T * Pr[d_k in mini-batch] / Pr[d_k in mini-batch])
~= exp(<q, d_0> / T) + sum over the minibatch negatives(exp(<q, d_k> / T) / Pr[d_k in mini-batch]
= exp(<q, d_0> / T) + sum over the minibatch negatives(exp(<q, d_k> / T - log(Pr[d_k in mini-batch]))

Это вывод формулы коррекции log-Q, и в первый момент она использует простую формулу. Обновленный код обучения выглядит так:

eps = 1e-6
batch_size = query_embedding.size(0)

logits = query_embedding @ document_embedding.T / temperature
labels = torch.arange(0, batch_size, dtype=torch.int64, device=device)

log_candidate_sampling_prob = batch['log_candidate_sampling_prob'].reshape(1, -1).repeat(batch_size, 1)
# need to set the correction for the positive label to 0
log_candidate_sampling_prob.fill_diagonal_(0)

loss = F.cross_entropy(logits - log_candidate_sampling_prob, labels)

Это дает преимущества, когда изученные вложения оцениваются по всему корпусу (с использованием любого стандартного алгоритма ИНС, такого как HNSW, если корпус большой). Теперь перейдем к тому, как можно оценить столбец log_candidate_sampling_prob.

Оценка вероятности выборки кандидатов

Стандартный подход заключается в оценке униграммной вероятности документа в обучающем корпусе. Если эта вероятность равна, скажем, p, а размер мини-пакета равен N, вероятность выборки кандидата можно легко аппроксимировать с помощью формулы: 1-(1-p )^Н.

Другой популярный подход — оценка log_candidate_sampling_prob потоковым способом с использованием эскиза count-min (в столбце document_id в мини-пакете).

class StreamingLogQCorrectionModule(nn.Module):
    def __init__(
            self,
            num_buckets: int,
            hash_offset: int,
            alpha: float,
            p_init: float,
    ):
        super().__init__()
        self.num_buckets = num_buckets
        self.hash_offset = hash_offset
        self.alpha = alpha
        self.register_buffer('b', (1.0 / p_init) * torch.ones((num_buckets,), dtype=torch.float32))
        self.register_buffer('a', torch.zeros((num_buckets,), dtype=torch.long))

    def forward(self, document_ids: torch.LongTensor) -> torch.Tensor:
        h = self.hash_fn(document_ids.view(-1))
        return - self.b[h].log().reshape(*document_ids.shape)

    def hash_fn(self, document_ids: torch.LongTensor) -> torch.LongTensor:
        return (document_ids + self.hash_offset) % self.num_buckets

    def train_step(self, document_ids: torch.LongTensor, batch_idx: int) -> None:
        h = self.hash_fn(document_ids).unique()
        self.b[h] = (1 - self.alpha) * self.b[h] + self.alpha * (batch_idx - self.a[h]).float()
        self.a[h] = batch_idx

Обычно несколько таких оценок каскадно объединяются следующим образом:

class CascadedStreamingLogQCorrectionModule(nn.Module):
    def __init__(
            self,
            num_buckets: int,
            hash_offsets: Tuple[int, ...],
            alpha: float,
            p_init: float,
    ):
        super().__init__()
        self.models = nn.ModuleList([
            StreamingLogQCorrectionModule(num_buckets, offset, alpha, p_init)
            for offset in hash_offsets
        ])

    def forward(self, document_ids: torch.LongTensor) -> torch.Tensor:
        result = torch.empty((0,), device=document_ids.device)
        for i, mod in enumerate(self.models):
            if i == 0:
                result = mod(document_ids)
            else:
                result = torch.minimum(result, mod(document_ids))
        return result

    def train_step(self, document_ids: torch.LongTensor, batch_idx: int) -> None:
        for mod in self.models:
            mod.train_step(document_ids, batch_idx)

Мы должны передать столбец document_id и индекс пакета методу train_step, чтобы обновить текущую статистику журнала вероятности выборки кандидата, а затем затем получить журнал выборки кандидата. вероятность путем вызова вышеуказанного модуля оценки вероятности каскадной потоковой коррекции log-Q.

Проблемы с коррекцией log-Q

Что происходит, когда корпус документов очень велик? Что делать, если внедрения документа не являются свободными параметрами (например, скрытые внедрения)? Что, если встраивания документов являются результатом модели контента, такой как скрытое состояние LLM или «преобразователь предложений»? Значения коррекции log-Q принимают очень большое отрицательное, примерно постоянное значение для всех документов. Это настолько вредит логарифмической аппроксимации, что модель без логарифмической коррекции имеет тенденцию работать лучше, чем модель с логарифмической коррекцией.

Так как же нам это исправить? Мы используем тот факт, что в таких условиях встраивания документа ограничены n-мерным пространством (обычно n-мерной единичной сферой). В n-мерной единичной сфере существует не так много «приблизительно уникальных» вложений, особенно для n достаточно небольшого размера. Таким образом, что касается контрастных потерь, мы сильно переоцениваем поправочный коэффициент log-Q для статистической суммы, если используем обычную формулу коррекции log-Q. Поэтому мы предлагаем использовать положение вложений документа в единичной сфере для оценки коррекции log-Q вместо использования оценки поправочного коэффициента log-Q на основе идентификатора документа.

Коррекция log-Q на основе встраивания документа

Мы используем тот же CascadedStreamingLogQCorrectionModule для оценки поправочного коэффициента log-Q. Однако вместо использования столбца document_id в качестве входных данных для этой оценки мы используем Locality Sensitive Hash (LSH) встраивания документа для оценки поправочного коэффициента log-Q.

LSH (различного разрешения) для n-мерной единичной сферы можно рассчитать с помощью следующего модуля с соответствующими значениями для n_proj и num_bins :

class LocalitySensitiveHashingModule(nn.Module):
    def __init__(self, emb_dim: int, n_proj: int, num_bins: int):
        super().__init__()
        self.register_buffer(
            'projection_mat',
            F.normalize(torch.randn((emb_dim, n_proj)), p=2.0, dim=0),
            persistent=True,
        )
        resolution = 2.0 / num_bins
        self.register_buffer(
            'grid',
            torch.linspace(-1, 1, num_bins + 1)[:-1] + 0.5 * resolution,
            persistent=True,
        )
        self.num_bins = num_bins

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = F.normalize(x, p=2.0, dim=-1) @ self.projection_mat
        z = torch.bucketize(z, self.grid).long()
        result = torch.empty((0,), device=x.device, dtype=torch.long)
        for i, t in enumerate(z.unbind(dim=-1)):
            if i == 0:
                result = t
            else:
                result = result * (1 + self.num_bins) + t
        return result

Этот подход может привести к даже лучшим результатам оценки ИНС для всего корпуса, чем стандартная коррекция log-Q с использованием document_id. Неясно, как развивается динамика обучения, поскольку модель документа также можно обучать вместе с моделью запроса. В таких случаях может потребоваться тщательно настроить параметр памяти alpha в CascadedStreamingLogQCorrectionModule.

Рекомендации