В машинном обучении постоянно встречается эта проблема - в датасете, на котором ты обучаешь нейросеть (или любую другую ML модель) - разное количество записей для разных классов. Иногда прям сильно разное. Если оставить это как есть, то нейросеть толком не выучит редкие классы и не научится их отличать.

Два основных подхода к решению этой проблемы - умножать loss для редких классов пропорционально их редкости, чтобы высокий loss заставлял сетку учиться на них и чаще читать записи редких классов, чтобы они перестали быть редкими. Эта статья про второй подход.

Итак, подготовка.

Создаем тензор датасет, в котором есть ровно 1000 нулей и 50 единичек. Нули и единички, понятное дело, представляют разные классы. Очевидно, классы не сбалансированы.

import torch
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.sampler import WeightedRandomSampler

ds = TensorDataset(torch.Tensor([ 0 for _ in range(1000)] + [ 1 for _ in range(50)]))

Теперь будем просить батчи из этого несбалансированного датасета разными способами. Батчи - это небольшие куски данных, коллекции определенной длины.

Первый способ, простейший

dl = DataLoader(ds, batch_size=50)

for _ in range(10):
    print(next(iter(dl))[0].mean())

# tensor(0.)
# tensor(0.)
# tensor(0.)
# tensor(0.)
# tensor(0.)
# tensor(0.)
# tensor(0.)
# tensor(0.)
# tensor(0.)
# tensor(0.)

Мы попросили из самого дефолтного даталоадера 10 батчей, в каждом батче по 50 цифр. Чтобы понять, из каких цифр состоят батчи - посчитали среднее для каждого из этих 10 батчей.

Очевидно, что сейчас здесь всегда нули - самый дефолтный даталоадер бежит по порядку по всем элементам. В начале у нас тысяча нулей, вот мы и получаем каждый раз в среднем 0. Просто!

Второй способ, чуть лучше

dl = DataLoader(ds, batch_size=50, shuffle=True)

for _ in range(10):
    print(next(iter(dl))[0].mean())

# tensor(0.0400)
# tensor(0.1000)
# tensor(0.0800)
# tensor(0.1400)
# tensor(0.1000)
# tensor(0.0200)
# tensor(0.0200)
# tensor(0.0400)
# tensor(0.0600)
# tensor(0.0400)

Сейчас мы получили не нули, так как мы перемешали все элементы с помощью shuffle=True и теперь в батчи иногда попадаются единички. Тоже просто!

Но средние по батчам далеки от 0.5, которое было бы, если бы мы получали нули и единички с одинаковой частотой, то есть если бы единички перестали быть редкими.

Третий способ и достижение баланса

Но все не так просто, нужна некоторая подготовка. Нам нужно знать - какая частота появления у какого класса в нашем датасете, чтобы правильно выставить веса для каждого значения:

from collections import Counter

# i.item() - вытаскивает питонье int значение из тензора
counter = Counter(i.item() for i, in ds)

weights = [1/counter.get(i.item()) for i, in ds]

Обратите внимание: изящный анпэкинг for i, in ds, мы уже писали про него ранее.

Теперь counter это Counter({0: 1000, 1: 50}), он просто посчитал уникальные элементы в датасете, а weights - это список весов для каждой цифры - можно сказать, “вероятностей” вытащить из датасета эту конкретную цифру.

Вероятность в кавычках, потому что сумма весов не обязана быть равна 1 - хитрый объект из следующего блока кода, в честь которого и названа эта статья там сам поделит на сумму весов, чтобы все правильно работало.

В нашем случае len(weights) == 1050 (ну, это понятно) и sum(weights) == 2 (в общем случае, она будет равна количеству классов, у нас класса 2 - нули и единицы)

Все, теперь мы готовы решить задачку

sampler = WeightedRandomSampler(weights, num_samples=len(weights))

dl = DataLoader(ds, batch_size=50, sampler=sampler)

for _ in range(10):
    print(next(iter(dl))[0].mean())

# tensor(0.7000)
# tensor(0.5000)
# tensor(0.2000)
# tensor(0.6000)
# tensor(0.4000)
# tensor(0.6000)
# tensor(0.6000)
# tensor(0.4000)
# tensor(0.6000)
# tensor(0.6000)

Теперь у нас в каждом батче, который мы просим, среднее получается всегда примерно 0.5 - а это значит, что нет больших редких классов, все они одинаково хорошо представлены в обучении. Вы великолепны!

Откуда это все?

То самое соревнование про симпсонов

Та же самая статья, только Kaggle ноутбук

Больше про рандом сэмплер тут

Респект ребятам из МФТИ - тык и тык

Наш телеграм канал