diff --git a/train/models.py b/train/models.py index 7a738aa..1532565 100644 --- a/train/models.py +++ b/train/models.py @@ -4,6 +4,7 @@ from preprocessing.nn_dataset import position_encoding_init from train.external import VariationalDropout from train.transformer import TransformerDecoder, TransformerDecoderLayer, TransformerEncoder, TransformerEncoderLayer +from huggingface_hub import PyTorchModelHubMixin """ File containing classes representing various Neural architectures @@ -11,7 +12,13 @@ # MARK:- TonicNet -class TonicNet(nn.Module): +class TonicNet( + nn.Module, + PyTorchModelHubMixin, + library_name="tonicnet", + repo_url="https://github.com/omarperacha/TonicNet", + tags=["polyphonic-music"], +): def __init__(self, nb_tags, nb_layers=1, z_dim =0, nb_rnn_units=100, batch_size=1, seq_len=1, dropout=0.0): super(TonicNet, self).__init__()