Я хочу реализовать «полноценный» алгоритм градиентного спуска в наборе данных, используя resnet. Однако, поскольку в полнопакетном градиенте на каждой итерации нам нужно вычислять градиент по всем точкам обучения, когда я использую tf.GradinetTape, возникает ошибка OOM. Это функция, которая вычисляет градиент:
@tf.function
def compute_gradients(training_ds): # this function computer gradients for other functions
image, label = training_ds
with tf.GradientTape() as tape:
predictions = model(image)
loss = loss_fn(label, predictions)
grad = tape.gradient(loss, model.trainable_variables)
return grad
Однако у меня ошибка OOM, так как сеть очень большая. Как я могу эффективно вычислить градиент для полного градиентного спуска?
CPU
, а не наGPU
. - person Tensorflow Warrior   schedule 26.05.2020