Custom Loss Function in Tensorflow not working

  python, tensorflow, tensorflow2.0

I have built a model using a custom image generator and a custom loss function in TensorFlow 2.
The batch_size is 62 and y_pred is a 15-dimensional vector.
I have checked the output of the generator function and it is working fine. I then checked the loss function:

def loss_func(y_true, y_pred):
    print('\ny_pred_shape', y_pred.shape)
    print('y_pred',y_pred)
    print('y_true_shape', y_true.shape)
    print('y_true',y_true)
    loc_loss = tf.keras.losses.binary_crossentropy(y_true[:, :4], y_pred[:, :4])
    cls_loss = tf.keras.losses.categorical_crossentropy(y_true[4:, :-1], y_pred[4:, :-1])
    obj_loss = tf.keras.losses.binary_crossentropy(y_true[-1], y_pred[-1])
    loss = loc_loss * y_true[-1] + cls_loss * y_true[-1] + obj_loss / 2
    print('loss',loss)
    return loss

The function ran twice when called by model.fit() and printed:

y_pred_shape (None, 15)
y_pred Tensor("functional_9/concatenate_4/concat:0", shape=(None, 15), dtype=float32)
y_true_shape (None, None)
y_true Tensor("IteratorGetNext:1", shape=(None, None), dtype=float32)
loss Tensor("loss_func/add_7:0", shape=(None,), dtype=float32)

y_pred_shape (None, 15)
y_pred Tensor("functional_9/concatenate_4/concat:0", shape=(None, 15), dtype=float32)
y_true_shape (None, None)
y_true Tensor("IteratorGetNext:1", shape=(None, None), dtype=float32)
loss Tensor("loss_func/add_7:0", shape=(None,), dtype=float32)

and finally raised this error:

InvalidArgumentError:  Incompatible shapes: [62] vs. [15]
     [[node loss_func/mul_5 (defined at <ipython-input-8-eadd25fe89e7>:36) ]] [Op:__inference_train_function_8568]

Errors may have originated from an input operation.
Input Source operations connected to node loss_func/mul_5:
 loss_func/Mean (defined at <ipython-input-8-eadd25fe89e7>:33)

Function call stack:
train_function

To test the loss function on its own, I ran this:

y = tf.random.uniform((62,15))
loss_func(y,y)

This gave me

InvalidArgumentError: Incompatible shapes: [62] vs. [15] [Op:Mul]

Can anyone please tell me where the mistake is and how to fix it?
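
A possible explanation, offered as a sketch rather than a confirmed fix: y_true[-1] selects the last row of the batch, which has shape (15,), while loc_loss holds one value per sample, shape (62,). That would account for the [62] vs. [15] mismatch in the multiplication. Similarly, y_true[4:, :-1] slices rows 4 onward rather than columns 4 onward. The version below slices along the column axis instead. It assumes the 15 values are laid out as 4 box terms, 10 class scores, and 1 objectness flag in the last column; that layout is inferred from the slices in the original function and is not stated anywhere.

```python
import tensorflow as tf

def loss_func(y_true, y_pred):
    # Assumed layout per sample: [4 box terms | 10 class scores | 1 objectness flag].
    obj_true = y_true[:, -1]  # last column, one value per sample: shape (batch,)

    # Each Keras loss below averages over the last axis and returns shape (batch,).
    loc_loss = tf.keras.losses.binary_crossentropy(y_true[:, :4], y_pred[:, :4])
    cls_loss = tf.keras.losses.categorical_crossentropy(y_true[:, 4:-1], y_pred[:, 4:-1])
    # Slice with -1: (not -1) to keep a trailing axis of size 1, so the result stays per-sample.
    obj_loss = tf.keras.losses.binary_crossentropy(y_true[:, -1:], y_pred[:, -1:])

    # All terms now have shape (batch,), so the element-wise products line up.
    return loc_loss * obj_true + cls_loss * obj_true + obj_loss / 2

# Same standalone check as in the question.
y = tf.random.uniform((62, 15))
loss = loss_func(y, y)
print(loss.shape)
```

With the same random input as in the standalone test, this returns one loss value per sample, i.e. a tensor of shape (62,), instead of raising the shape error.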

Source: Python Questions
