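Hi, I'm a beginner, and I found that

attn = Lambda(lambda x: K.batch_dot(x[0], x[1], axes=[2, 2]) / self.temper)([q, k])

is equivalent to

attn = Lambda(lambda x: tf.matmul(x[0], x[1], transpose_b=True) / self.temper)([q, k])

at least when q and k are 3-D tensors of shape (batch, length, d_k). Both contract the last axis of q with the last axis of k (that is, q multiplied by k transposed), so both return a tensor of shape (batch, len_q, len_k). Sorry to bother you.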
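A quick way to check this without Keras is the small NumPy sketch below. It does not call K.batch_dot or tf.matmul themselves; as an assumption, it mimics batch_dot(axes=[2, 2]) with np.einsum and matmul(transpose_b=True) with np.matmul on k with its last two axes swapped. The shapes and the temper value (here sqrt(d_k)) are hypothetical.

```python
import numpy as np

# Hypothetical shapes: batch=2, query length=3, key length=4, d_k=5
rng = np.random.default_rng(0)
q = rng.standard_normal((2, 3, 5))
k = rng.standard_normal((2, 4, 5))
temper = np.sqrt(5.0)  # assumed scaling factor, e.g. sqrt(d_k)

# Like batch_dot with axes=[2, 2]: contract axis 2 of q with axis 2 of k
attn_batch_dot = np.einsum("bqd,bkd->bqk", q, k) / temper

# Like matmul with transpose_b=True: transpose only the last two axes of k
attn_matmul = np.matmul(q, np.swapaxes(k, -1, -2)) / temper

print(attn_batch_dot.shape)  # (2, 3, 4)
print(np.allclose(attn_batch_dot, attn_matmul))  # True
```

Both results have shape (batch, len_q, len_k) and agree up to floating-point rounding, which is why np.allclose is used instead of exact equality.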