полный пакетный градиентный спуск с tf.GradientTape: ошибка OOM

Я хочу реализовать «полноценный» алгоритм градиентного спуска в наборе данных, используя resnet. Однако, поскольку в полнопакетном градиенте на каждой итерации нам нужно вычислять градиент по всем точкам обучения, когда я использую tf.GradinetTape, возникает ошибка OOM. Это функция, которая вычисляет градиент:

@tf.function
def compute_gradients(training_ds):  # this function computer gradients for other functions
    image, label  = training_ds
    with tf.GradientTape() as tape:
      predictions = model(image)
      loss = loss_fn(label, predictions)    
    grad = tape.gradient(loss, model.trainable_variables)
return grad

Однако у меня ошибка OOM, так как сеть очень большая. Как я могу эффективно вычислить градиент для полного градиентного спуска?


person Mahdi    schedule 24.05.2020    source источник
comment
Из-за ограничений памяти вам нужно работать с меньшими входными данными ИЛИ уменьшить размер сети. Также убедитесь, что вы работаете на CPU, а не на GPU.   -  person Tensorflow Warrior    schedule 26.05.2020
comment
@TensorflowWarriors: спасибо за ваш комментарий. Не могли бы вы привести пример сценария использования этого подхода? Я новичок в ТФ2. Я реализовал вашу идею, но когда я украсил функцию tf.function, она выдает ошибку.   -  person Mahdi    schedule 27.05.2020
comment
С какой ошибкой вы столкнулись? Также поделитесь полным кодом здесь ИЛИ добавьте свой код в файл Google Colab и поделитесь ссылкой здесь.   -  person Tensorflow Warrior    schedule 27.05.2020