Come stampare il numero di parametri del modello in PyTorch

Come Stampare Il Numero Di Parametri Del Modello In Pytorch



PyTorch è un framework popolare utilizzato nel deep learning. Offre molteplici funzionalità per la generazione di reti neurali complesse (NN). Gli utenti possono eseguire operazioni di training del modello con questo framework. Tuttavia, gli utenti devono avere familiarità con il numero di parametri prima di addestrare il modello.

Questo blog descriverà:

Quali sono i parametri in PyTorch?

In PyTorch, il ' nn.Modulo La classe ' viene utilizzata per definire i modelli. Comprende tutte le operazioni e i livelli che compongono il modello. Ogni livello contiene una serie di parametri. I parametri vengono sostanzialmente aggiornati durante l'addestramento per ridurre al minimo l'errore tra i valori effettivi del modello e le previsioni.







Perché gli utenti devono verificare i parametri del modello?

Durante l'addestramento del modello, gli utenti devono conoscere il numero di parametri del proprio modello perché richiede molta memoria e potenza di elaborazione. Se hanno familiarità con il numero di parametri del modello, possono facilmente valutare la quantità di memoria che sarà richiesta e quanto tempo ci vorrà per l'addestramento, il che aiuta gli utenti a ottimizzare il processo di addestramento e a evitare che il sistema rimanga senza risorse. spazio.



Come visualizzare il numero di parametri del modello in PyTorch?

IL ' nn.Modulo La classe ' ha il ' parametri() ' Metodo utilizzato per visualizzare il numero di parametri del modello nel modello PyTorch. Per ottenere tutti gli elementi, il ' numero1() viene utilizzato il metodo '.



Per comprendere il concetto discusso in precedenza, diamo un'occhiata al codice fornito:





importare torcia. nn COME nn

classe NNModel ( nn. Modulo ) :
def __caldo__ ( se stesso ) :
super ( NNModel , se stesso ) . __caldo__ ( )
se stesso . fc1 = nn. Lineare ( 10 , cinquanta )
se stesso . fc2 = nn. Lineare ( cinquanta , 1 )

def inoltrare ( se stesso , io ) :
io = se stesso . fc1 ( io )
io = se stesso . fc2 ( io )
ritorno io

mio_modello = NNModel ( )
t_params = somma ( P. dare un nome ( ) per P In mio_modello. parametri ( ) )
stampa ( F 'Numero totale di parametri: {t_params}' )

Nel codice sopra indicato:

  • Innanzitutto, definiamo un modello che ha due strati lineari.
  • Quindi, genera l'istanza del modello e utilizza il comando ' parametri() ' metodo per recuperare tutti i parametri.
  • Successivamente, applichiamo l'espressione del generatore per calcolare tutti i parametri sommando il numero di elementi di ciascun parametro.
  • Infine, chiama il ' stampa() ' istruzione per visualizzare i valori risultanti sullo schermo:



Nel codice sopra descritto, abbiamo visualizzato solo il numero totale di parametri, se si desidera ottenere il nome e la dimensione del parametro, è possibile utilizzare le seguenti righe di codice:

per nome , param In mio_modello. stato_dict ( ) . elementi ( ) :

stampa ( nome , param. misurare ( ) )

Qui:

  • state_dict() ' è l'oggetto dizionario Python utilizzato per archiviare e caricare modelli da PyTorch.
  • articolo() 'Il metodo viene utilizzato per restituire l'elenco con tutte le chiavi del dizionario insieme ai valori.
  • stampa() L'istruzione ' viene utilizzata per stampare il nome e la dimensione del parametro passando il comando ' misurare() 'metodo e parametro:

È tutto! Abbiamo compilato il modo più semplice per stampare il numero di parametri del modello in PyTorch.

Conclusione

In PyTorch, il ' nn.Modulo La classe ' viene utilizzata per definire i modelli che includono tutte le operazioni e i livelli che compongono il modello. IL ' nn.Modulo La classe ' ha il ' parametri() ' Metodo utilizzato per visualizzare il numero di parametri del modello nel modello PyTorch. Questo articolo ha dimostrato il metodo per stampare il numero di parametri del modello in PyTorch.