OOM on TensorFlow 2.3 only when using a dataset built with from_tensor_slices

I’m trying to build a neural net that is deliberately overfit so that it transforms one specific input image into one specific output image. The transformation should be achievable pixel by pixel, so for now I am using a fully connected net rather than a convolutional one.

When I wrap my input and output image data in a tf.data.Dataset, training crashes with an out-of-memory error on my machine with 128 GB, yet when I pass the inputs and outputs directly to model.fit it runs (even if the output is garbage). Can anyone explain why tf.data.Dataset doesn’t work for me here?
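One possible explanation, offered as an assumption rather than a confirmed diagnosis: from_tensor_slices slices along the first axis, so the dataset would hold a single element of shape (49152,) with no batch dimension. If Keras then treats those 49152 values as 49152 separate one-feature samples, the final Dense layer would produce a 49152 x 49152 float32 output. The arithmetic below estimates the size of just that one tensor; the helper name tensor_gib is hypothetical.

```python
# Rough memory estimate for a single float32 activation tensor.
# Assumption: the unbatched element of shape (49152,) is treated as
# 49152 samples, each passed through a final layer with 49152 units.

BYTES_PER_FLOAT32 = 4


def tensor_gib(rows, cols, bytes_per_value=BYTES_PER_FLOAT32):
    """Size in GiB of a dense rows x cols tensor."""
    return rows * cols * bytes_per_value / 2**30


pixels = 192 * 256                      # 49152, same as 147456 // 3 in the question
unbatched = tensor_gib(pixels, pixels)  # hypothesised unbatched case
batched = tensor_gib(1, pixels)         # one sample with 49152 features

print(f"pixels: {pixels}")
print(f"unbatched output: {unbatched:.2f} GiB")
print(f"batched output:   {batched * 2**20:.0f} KiB")
```

The forward pass, the gradients, and any intermediate copies would each need tensors of this size, so the total could be several times larger. Whether this is what actually exhausts memory here has not been verified.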

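import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dense
from PIL import Image

ALPHA=0.0001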
EPOCHS=500
DATA_DIR='/home/tf2/Documents/NNimageconverter/separate/'

input_files=DATA_DIR + 'inputs/PA111376.resized.png'
output_files=DATA_DIR + 'outputs/PA111375.resized.png'

# Load both images as single-channel (grayscale) arrays
input_images=keras.preprocessing.image.load_img(input_files, color_mode='grayscale')
input_images=np.asarray(input_images)

output_images=keras.preprocessing.image.load_img(output_files, color_mode='grayscale')
output_images=np.asarray(output_images)

# Flatten each image into a single row of shape (1, height*width)
input_images=input_images.reshape(1,-1)
output_images=output_images.reshape(1,-1)

#input_images=input_images/255
#output_images=output_images/255

dataset=tf.data.Dataset.from_tensor_slices((input_images, output_images))


model = keras.Sequential()
#model.add(experimental.preprocessing.Rescaling(1./255))
#model.add(Flatten())
model.add(Dense(128,activation='relu'))
model.add(Dense(128, activation='relu'))
model.add(Dense(128, activation='relu'))
model.add(Dense(128, activation='relu'))
model.add(Dense(128, activation='relu'))
model.add(Dense(128, activation='relu'))
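# 147456 // 3 == 49152 == 192 * 256: one output unit per pixel (integer division keeps units an int)
model.add(Dense(147456 // 3, activation='softplus'))
modelOpt=tf.optimizers.RMSprop(learning_rate=ALPHA)
model.compile(optimizer=modelOpt, loss=keras.losses.mean_absolute_error, metrics=['accuracy'])
# Passing input_images and output_images directly instead of dataset runs without the crash
model.fit(dataset, epochs=EPOCHS)
result = model.predict(input_images)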

#result = result*255

result=result.reshape((192,256))
#print(result)
img = Image.fromarray(result, 'L')
img.show()
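
If the missing batch dimension is indeed the cause (an assumption, not something confirmed in the question), batching the dataset should make it yield the same (1, N) shapes that the direct model.fit(input_images, output_images) call sees. A minimal sketch with small stand-in arrays instead of the real images:

```python
import numpy as np
import tensorflow as tf

# Small stand-ins for the flattened images (the real ones have 49152 values).
inputs = np.arange(8, dtype=np.float32).reshape(1, -1)
targets = np.arange(8, dtype=np.float32).reshape(1, -1)

# from_tensor_slices slices along axis 0, so each element has shape (8,).
unbatched = tf.data.Dataset.from_tensor_slices((inputs, targets))
# batch(1) restores the leading batch dimension.
batched = unbatched.batch(1)

# Pull the first (and only) element from each dataset and record the input shape.
unbatched_shape = next(iter(unbatched))[0].shape.as_list()
batched_shape = next(iter(batched))[0].shape.as_list()

print("unbatched element shape:", unbatched_shape)
print("batched element shape:  ", batched_shape)
```

In the question's script this would correspond to dataset=tf.data.Dataset.from_tensor_slices((input_images, output_images)).batch(1).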
