From 46fde091e675b223749f39e7280e1d004abf0b98 Mon Sep 17 00:00:00 2001 From: not-lain Date: Sat, 24 Aug 2024 22:57:13 +0000 Subject: [PATCH] inherit from PyTorchModelHubMixin --- train/models.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/train/models.py b/train/models.py index 7a738aa..1532565 100644 --- a/train/models.py +++ b/train/models.py @@ -4,6 +4,7 @@ from preprocessing.nn_dataset import position_encoding_init from train.external import VariationalDropout from train.transformer import TransformerDecoder, TransformerDecoderLayer, TransformerEncoder, TransformerEncoderLayer +from huggingface_hub import PyTorchModelHubMixin """ File containing classes representing various Neural architectures @@ -11,7 +12,13 @@ # MARK:- TonicNet -class TonicNet(nn.Module): +class TonicNet( + nn.Module, + PyTorchModelHubMixin, + library_name="tonicnet", + repo_url="https://github.com/omarperacha/TonicNet", + tags=["polyphonic-music"], +): def __init__(self, nb_tags, nb_layers=1, z_dim =0, nb_rnn_units=100, batch_size=1, seq_len=1, dropout=0.0): super(TonicNet, self).__init__()