# Encode the test images, then reconstruct them from their codes.
encoded_imgs = encoder.predict(x_test)
decoded_imgs = decoder.predict(encoded_imgs)

n = 10  # number of test samples shown (one per grid column)
plt.figure(figsize=(20, 6))
# Make grayscale the default colormap once, rather than re-applying it
# after every single imshow call inside the loop.
plt.gray()

# One grid row per stage: (source array, 2-D shape each sample is reshaped
# to for display).
rows = (
    (x_test, (28, 28)),        # original input
    (encoded_imgs, (6, 6)),    # encoded representation (not a reconstruction)
    (decoded_imgs, (28, 28)),  # reconstruction produced by the decoder
)
for row, (images, shape) in enumerate(rows):
    for i in range(n):
        # Subplot indices are 1-based and run row-major across the grid.
        ax = plt.subplot(len(rows), n, row * n + i + 1)
        plt.imshow(images[i].reshape(shape))
        # Hide the axes: tick marks carry no meaning for image tiles.
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
plt.show()