Facebook
From gajender, 2 Months ago, written in Plain Text.
Embed
Download Paste or View Raw
Hits: 318
  1. # create and compile the model
  2. model = DiffusionModel(image_size, widths, block_depth)
  3. # below tensorflow 2.9:
  4. # pip install tensorflow_addons
  5. # import tensorflow_addons as tfa
  6. # optimizer=tfa.optimizers.AdamW
  7. model.compile(
  8.     optimizer=keras.optimizers.experimental.AdamW(
  9.         learning_rate=learning_rate, weight_decay=weight_decay
  10.     ),
  11.     loss=keras.losses.mean_absolute_error,
  12. )
  13. # pixelwise mean absolute error is used as loss
  14.  
  15. # save the best model based on the validation KID metric
  16. checkpoint_path = "checkpoints/diffusion_model"
  17. checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
  18.     filepath=checkpoint_path,
  19.     save_weights_only=True,
  20.     monitor="val_kid",
  21.     mode="min",
  22.     save_best_only=True,
  23. )
  24.  
  25. # calculate mean and variance of training dataset for normalization
  26. model.normalizer.adapt(train_dataset)
  27.  
  28. # run training and plot generated images periodically
  29. model.fit(
  30.     train_dataset,
  31.     epochs=num_epochs,
  32.     validation_data=val_dataset,
  33.     callbacks=[
  34.         keras.callbacks.LambdaCallback(on_epoch_end=model.plot_images),
  35.         checkpoint_callback,
  36.     ],
  37. )