Τι είναι το Google JAX; Όλα όσα πρέπει να γνωρίζετε

Το Google JAX ή Just After Execution είναι ένα πλαίσιο που αναπτύχθηκε από την Google για να επιταχύνει τις εργασίες μηχανικής εκμάθησης.

Μπορείτε να το θεωρήσετε μια βιβλιοθήκη για την Python, η οποία βοηθά στην ταχύτερη εκτέλεση εργασιών, επιστημονικούς υπολογισμούς, μετασχηματισμούς συναρτήσεων, βαθιά μάθηση, νευρωνικά δίκτυα και πολλά άλλα.

Σχετικά με το Google JAX

Το πιο θεμελιώδες πακέτο υπολογισμού στην Python είναι το πακέτο NumPy το οποίο έχει όλες τις λειτουργίες όπως συναθροίσεις, διανυσματικές πράξεις, γραμμική άλγεβρα, χειρισμούς n-διαστάσεων πίνακα και μήτρας και πολλές άλλες προηγμένες συναρτήσεις.

Τι θα γινόταν αν μπορούσαμε να επιταχύνουμε περαιτέρω τους υπολογισμούς που εκτελούνται χρησιμοποιώντας το NumPy – ιδιαίτερα για τεράστια σύνολα δεδομένων;

Έχουμε κάτι που θα μπορούσε να λειτουργήσει εξίσου καλά σε διαφορετικούς τύπους επεξεργαστών όπως GPU ή TPU, χωρίς αλλαγές στον κώδικα;

Τι θα λέγατε εάν το σύστημα μπορούσε να εκτελέσει μετασχηματισμούς συναρτήσεων αυτόματα και πιο αποτελεσματικά;

Το Google JAX είναι μια βιβλιοθήκη (ή πλαίσιο, όπως λέει η Wikipedia) που κάνει ακριβώς αυτό και ίσως πολλά περισσότερα. Κατασκευάστηκε για τη βελτιστοποίηση της απόδοσης και την αποτελεσματική εκτέλεση εργασιών μηχανικής μάθησης (ML) και βαθιάς μάθησης. Το Google JAX παρέχει τις ακόλουθες δυνατότητες μετασχηματισμού που το καθιστούν μοναδικό από άλλες βιβλιοθήκες ML και βοηθούν στον προηγμένο επιστημονικό υπολογισμό για βαθιά μάθηση και νευρωνικά δίκτυα:

  • Αυτόματη διαφοροποίηση
  • Αυτόματη διανυσματοποίηση
  • Αυτόματη παραλληλοποίηση
  • Σύνταξη Just-in-time (JIT).

Τα μοναδικά χαρακτηριστικά του Google JAX

Όλοι οι μετασχηματισμοί χρησιμοποιούν XLA (Accelerated Linear Algebra) για υψηλότερη απόδοση και βελτιστοποίηση μνήμης. Το XLA είναι μια μηχανή μεταγλώττισης βελτιστοποίησης για συγκεκριμένο τομέα που εκτελεί γραμμική άλγεβρα και επιταχύνει μοντέλα TensorFlow. Η χρήση XLA πάνω από τον κώδικα Python σας δεν απαιτεί σημαντικές αλλαγές στον κώδικα!

Ας εξερευνήσουμε λεπτομερώς καθένα από αυτά τα χαρακτηριστικά.

Χαρακτηριστικά του Google JAX

Το Google JAX συνοδεύεται από σημαντικές συνθετικές συναρτήσεις μετασχηματισμού για τη βελτίωση της απόδοσης και την αποτελεσματικότερη εκτέλεση εργασιών βαθιάς εκμάθησης. Για παράδειγμα, αυτόματη διαφοροποίηση για να πάρετε τη διαβάθμιση μιας συνάρτησης και να βρείτε παραγώγους οποιασδήποτε τάξης. Ομοίως, αυτόματη παραλληλοποίηση και JIT για την εκτέλεση πολλαπλών εργασιών παράλληλα. Αυτοί οι μετασχηματισμοί είναι βασικοί σε εφαρμογές όπως η ρομποτική, τα παιχνίδια, ακόμη και η έρευνα.

Μια συνάρτηση μετασχηματισμού με δυνατότητα σύνθεσης είναι μια καθαρή συνάρτηση που μετατρέπει ένα σύνολο δεδομένων σε άλλη μορφή. Ονομάζονται composable καθώς είναι αυτόνομες (δηλαδή, αυτές οι συναρτήσεις δεν έχουν εξαρτήσεις με το υπόλοιπο πρόγραμμα) και είναι χωρίς κατάσταση (δηλαδή, η ίδια είσοδος θα έχει πάντα ως αποτέλεσμα την ίδια έξοδο).

Y(x) = T: (f(x))

Στην παραπάνω εξίσωση, f(x) είναι η αρχική συνάρτηση στην οποία εφαρμόζεται ένας μετασχηματισμός. Το Y(x) είναι η συνάρτηση που προκύπτει μετά την εφαρμογή του μετασχηματισμού.

Για παράδειγμα, εάν έχετε μια συνάρτηση με το όνομα “total_bill_amt” και θέλετε το αποτέλεσμα ως μετασχηματισμό συνάρτησης, μπορείτε απλά να χρησιμοποιήσετε τον μετασχηματισμό που θέλετε, ας πούμε gradient (grad):

  Πώς να αποφασίσετε ποιος διακόπτης Nintendo είναι κατάλληλος για εσάς

grad_total_bill = grad(total_bill_amt)

Μετασχηματίζοντας αριθμητικές συναρτήσεις χρησιμοποιώντας συναρτήσεις όπως η grad(), μπορούμε εύκολα να πάρουμε τις παραγώγους υψηλότερης τάξης τους, τις οποίες μπορούμε να χρησιμοποιήσουμε εκτενώς σε αλγόριθμους βελτιστοποίησης βαθιάς μάθησης, όπως το gradient descent, κάνοντας έτσι τους αλγόριθμους ταχύτερους και αποτελεσματικότερους. Ομοίως, χρησιμοποιώντας το jit(), μπορούμε να μεταγλωττίσουμε προγράμματα Python just-in-time (lazily).

#1. Αυτόματη διαφοροποίηση

Η Python χρησιμοποιεί τη συνάρτηση autograd για να διαφοροποιήσει αυτόματα τον NumPy και τον εγγενή κώδικα Python. Το JAX χρησιμοποιεί μια τροποποιημένη έκδοση του autograd (δηλαδή, grad) και συνδυάζει το XLA (Accelerated Linear Algebra) για την εκτέλεση αυτόματης διαφοροποίησης και την εύρεση παραγώγων οποιασδήποτε τάξης για GPU (Μονάδες Επεξεργασίας Γραφικών) και TPU (Μονάδες Επεξεργασίας Τενυστήρα).]

Γρήγορη σημείωση για TPU, GPU και CPU: Η CPU ή η κεντρική μονάδα επεξεργασίας διαχειρίζεται όλες τις λειτουργίες στον υπολογιστή. Η GPU είναι ένας πρόσθετος επεξεργαστής που ενισχύει την υπολογιστική ισχύ και εκτελεί λειτουργίες προηγμένης τεχνολογίας. Το TPU είναι μια ισχυρή μονάδα που αναπτύχθηκε ειδικά για πολύπλοκους και βαρείς φόρτους εργασίας όπως AI και αλγόριθμους βαθιάς μάθησης.

Στην ίδια γραμμή με τη συνάρτηση autograd, η οποία μπορεί να διαφοροποιηθεί μέσω βρόχων, αναδρομών, διακλαδώσεων και ούτω καθεξής, το JAX χρησιμοποιεί τη συνάρτηση grad() για διαβαθμίσεις αντίστροφης λειτουργίας (backpropagation). Επίσης, μπορούμε να διαφοροποιήσουμε μια συνάρτηση σε οποιαδήποτε σειρά χρησιμοποιώντας grad:

grad(grad(grad(sin θ))) (1.0)

Αυτόματη διαφοροποίηση ανώτερης τάξης

Όπως αναφέραμε προηγουμένως, το grad είναι αρκετά χρήσιμο για την εύρεση των μερικών παραγώγων μιας συνάρτησης. Μπορούμε να χρησιμοποιήσουμε μια μερική παράγωγο για να υπολογίσουμε τη βαθμιδωτή κάθοδο μιας συνάρτησης κόστους σε σχέση με τις παραμέτρους του νευρωνικού δικτύου στη βαθιά μάθηση για να ελαχιστοποιήσουμε τις απώλειες.

Υπολογισμός μερικής παραγώγου

Ας υποθέσουμε ότι μια συνάρτηση έχει πολλές μεταβλητές, x, y και z. Η εύρεση της παραγώγου μιας μεταβλητής διατηρώντας τις άλλες μεταβλητές σταθερές ονομάζεται μερική παράγωγος. Ας υποθέσουμε ότι έχουμε μια συνάρτηση,

f(x,y,z) = x + 2y + z2

Παράδειγμα για την εμφάνιση μερικής παραγώγου

Η μερική παράγωγος του x θα είναι ∂f/∂x, που μας λέει πώς αλλάζει μια συνάρτηση για μια μεταβλητή όταν οι άλλες είναι σταθερές. Εάν το εκτελέσουμε χειροκίνητα, πρέπει να γράψουμε ένα πρόγραμμα για να διαφοροποιήσουμε, να το εφαρμόσουμε για κάθε μεταβλητή και στη συνέχεια να υπολογίσουμε την κατάβαση κλίσης. Αυτό θα γίνει μια πολύπλοκη και χρονοβόρα υπόθεση για πολλαπλές μεταβλητές.

Η αυτόματη διαφοροποίηση αναλύει τη συνάρτηση σε ένα σύνολο στοιχειωδών πράξεων, όπως +, -, *, / ή sin, cos, tan, exp, κ.λπ., και στη συνέχεια εφαρμόζει τον κανόνα της αλυσίδας για τον υπολογισμό της παραγώγου. Μπορούμε να το κάνουμε αυτό τόσο σε λειτουργία εμπρός όσο και σε αντίστροφη λειτουργία.

Δεν είναι αυτό! Όλοι αυτοί οι υπολογισμοί γίνονται τόσο γρήγορα (καλά, σκεφτείτε ένα εκατομμύριο υπολογισμούς παρόμοιους με τους παραπάνω και τον χρόνο που μπορεί να χρειαστεί!). Η XLA φροντίζει για την ταχύτητα και την απόδοση.

#2. Επιταχυνόμενη Γραμμική Άλγεβρα

Ας πάρουμε την προηγούμενη εξίσωση. Χωρίς XLA, ο υπολογισμός θα πάρει τρεις (ή περισσότερους) πυρήνες, όπου κάθε πυρήνας θα εκτελεί μια μικρότερη εργασία. Για παράδειγμα,

  Πώς να ενεργοποιήσετε τις ειδοποιήσεις φλας LED μόνο όταν το iPhone σας είναι αθόρυβο

Πυρήνας k1 –> x * 2y (πολλαπλασιασμός)

k2 –> x * 2y + z (προσθήκη)

k3 –> Μείωση

Εάν η ίδια εργασία εκτελείται από το XLA, ένας μόνος πυρήνας φροντίζει για όλες τις ενδιάμεσες λειτουργίες συντήκοντάς τις. Τα ενδιάμεσα αποτελέσματα των στοιχειωδών λειτουργιών μεταδίδονται σε ροή αντί να αποθηκεύονται στη μνήμη, εξοικονομώντας έτσι τη μνήμη και βελτιώνοντας την ταχύτητα.

#3. Just-in-time συλλογή

Το JAX χρησιμοποιεί εσωτερικά τον μεταγλωττιστή XLA για να αυξήσει την ταχύτητα της εκτέλεσης. Το XLA μπορεί να αυξήσει την ταχύτητα της CPU, της GPU και της TPU. Όλα αυτά είναι δυνατά χρησιμοποιώντας την εκτέλεση κώδικα JIT. Για να το χρησιμοποιήσουμε, μπορούμε να χρησιμοποιήσουμε το jit μέσω εισαγωγής:

from jax import jit
def my_function(x):
	…………some lines of code
my_function_jit = jit(my_function)

Ένας άλλος τρόπος είναι διακοσμώντας το jit πάνω από τον ορισμό της συνάρτησης:

@jit
def my_function(x):
	…………some lines of code

Αυτός ο κώδικας είναι πολύ πιο γρήγορος επειδή ο μετασχηματισμός θα επιστρέψει τη μεταγλωττισμένη έκδοση του κώδικα στον καλούντα αντί να χρησιμοποιήσει τον διερμηνέα Python. Αυτό είναι ιδιαίτερα χρήσιμο για εισόδους διανυσμάτων, όπως πίνακες και πίνακες.

Το ίδιο ισχύει και για όλες τις υπάρχουσες συναρτήσεις python. Για παράδειγμα, συναρτήσεις από το πακέτο NumPy. Σε αυτήν την περίπτωση, θα πρέπει να εισαγάγουμε το jax.numpy ως jnp αντί για το NumPy:

import jax
import jax.numpy as jnp

x = jnp.array([[1,2,3,4], [5,6,7,8]])

Μόλις το κάνετε αυτό, το αντικείμενο του πυρήνα του πίνακα JAX που ονομάζεται DeviceArray αντικαθιστά τον τυπικό πίνακα NumPy. Το DeviceArray είναι τεμπέλικο – οι τιμές διατηρούνται στο γκάζι μέχρι να χρειαστούν. Αυτό σημαίνει επίσης ότι το πρόγραμμα JAX δεν περιμένει τα αποτελέσματα να επιστρέψουν στο πρόγραμμα κλήσης (Python), ακολουθώντας έτσι μια ασύγχρονη αποστολή.

#4. Αυτόματη διανυσματοποίηση (vmap)

Σε έναν τυπικό κόσμο μηχανικής μάθησης, έχουμε σύνολα δεδομένων με ένα εκατομμύριο ή περισσότερα σημεία δεδομένων. Πιθανότατα, θα εκτελούσαμε μερικούς υπολογισμούς ή χειρισμούς σε καθένα ή στα περισσότερα από αυτά τα σημεία δεδομένων – κάτι που είναι πολύ χρονοβόρο και χρονοβόρο έργο! Για παράδειγμα, αν θέλετε να βρείτε το τετράγωνο καθενός από τα σημεία δεδομένων στο σύνολο δεδομένων, το πρώτο πράγμα που θα σκεφτόσασταν είναι να δημιουργήσετε έναν βρόχο και να πάρετε το τετράγωνο ένα προς ένα – argh!

Εάν δημιουργήσουμε αυτά τα σημεία ως διανύσματα, θα μπορούσαμε να κάνουμε όλα τα τετράγωνα με μία κίνηση εκτελώντας χειρισμούς διανυσμάτων ή πινάκων στα σημεία δεδομένων με το αγαπημένο μας NumPy. Και αν το πρόγραμμά σας μπορούσε να το κάνει αυτόματα – μπορείτε να ζητήσετε κάτι περισσότερο; Αυτό ακριβώς κάνει το JAX! Μπορεί να διανύσει αυτόματα όλα τα σημεία δεδομένων σας, ώστε να μπορείτε να εκτελέσετε εύκολα οποιεσδήποτε λειτουργίες σε αυτά – κάνοντας τους αλγόριθμούς σας πολύ πιο γρήγορους και πιο αποτελεσματικούς.

Το JAX χρησιμοποιεί τη λειτουργία vmap για αυτόματη διανυσματοποίηση. Σκεφτείτε τον ακόλουθο πίνακα:

x = jnp.array([1,2,3,4,5,6,7,8,9,10])
y = jnp.square(x)

Κάνοντας μόνο τα παραπάνω, η μέθοδος τετραγώνου θα εκτελεστεί για κάθε σημείο του πίνακα. Αλλά αν κάνετε τα εξής:

vmap(jnp.square(x))

Το τετράγωνο της μεθόδου θα εκτελεστεί μόνο μία φορά επειδή τα σημεία δεδομένων διανυσματοποιούνται αυτόματα χρησιμοποιώντας τη μέθοδο vmap πριν από την εκτέλεση της συνάρτησης και ο βρόχος πιέζεται προς τα κάτω στο στοιχειώδες επίπεδο λειτουργίας – με αποτέλεσμα πολλαπλασιασμό πίνακα αντί για βαθμωτό πολλαπλασιασμό, δίνοντας έτσι καλύτερη απόδοση .

  Πώς να διορθώσετε το σφάλμα "Αυτή η συσκευή δεν μπορεί να ξεκινήσει". (Κωδικός 10)

#5. Προγραμματισμός SPMD (pmap)

SPMD – ή Ο προγραμματισμός πολλαπλών δεδομένων ενός προγράμματος είναι απαραίτητος σε περιβάλλοντα βαθιάς μάθησης – συχνά εφαρμόζετε τις ίδιες λειτουργίες σε διαφορετικά σύνολα δεδομένων που βρίσκονται σε πολλαπλές GPU ή TPU. Το JAX έχει μια λειτουργία που ονομάζεται αντλία, η οποία επιτρέπει τον παράλληλο προγραμματισμό σε πολλαπλές GPU ή οποιονδήποτε επιταχυντή. Όπως το JIT, τα προγράμματα που χρησιμοποιούν pmap θα μεταγλωττίζονται από το XLA και θα εκτελούνται ταυτόχρονα σε όλα τα συστήματα. Αυτή η αυτόματη παραλληλοποίηση λειτουργεί τόσο για μπροστινούς όσο και για ανάστροφους υπολογισμούς.

Πώς λειτουργεί το pmap

Μπορούμε επίσης να εφαρμόσουμε πολλαπλούς μετασχηματισμούς με μία κίνηση με οποιαδήποτε σειρά σε οποιαδήποτε συνάρτηση όπως:

pmap(vmap(jit(grad (f(x)))))

Πολλαπλοί συνθετικοί μετασχηματισμοί

Περιορισμοί του Google JAX

Οι προγραμματιστές του Google JAX έχουν σκεφτεί καλά να επιταχύνουν τους αλγόριθμους βαθιάς εκμάθησης, ενώ εισάγουν όλους αυτούς τους φοβερούς μετασχηματισμούς. Οι συναρτήσεις και τα πακέτα επιστημονικών υπολογισμών είναι στις γραμμές του NumPy, επομένως δεν χρειάζεται να ανησυχείτε για την καμπύλη εκμάθησης. Ωστόσο, το JAX έχει τους ακόλουθους περιορισμούς:

  • Το Google JAX βρίσκεται ακόμα στα αρχικά στάδια ανάπτυξης και, παρόλο που ο κύριος σκοπός του είναι η βελτιστοποίηση της απόδοσης, δεν προσφέρει πολλά οφέλη για τους υπολογιστές της CPU. Το NumPy φαίνεται να αποδίδει καλύτερα και η χρήση του JAX μπορεί να προσθέσει μόνο τα γενικά έξοδα.
  • Το JAX βρίσκεται ακόμα στην έρευνά του ή στα αρχικά του στάδια και χρειάζεται περισσότερη λεπτομέρεια για να επιτύχει τα πρότυπα υποδομής πλαισίων όπως το TensorFlow, τα οποία είναι πιο καθιερωμένα και έχουν περισσότερα προκαθορισμένα μοντέλα, έργα ανοιχτού κώδικα και εκπαιδευτικό υλικό.
  • Προς το παρόν, το JAX δεν υποστηρίζει λειτουργικό σύστημα Windows – θα χρειαστείτε μια εικονική μηχανή για να λειτουργήσει.
  • Το JAX λειτουργεί μόνο σε καθαρές λειτουργίες – αυτές που δεν έχουν παρενέργειες. Για λειτουργίες με παρενέργειες, το JAX μπορεί να μην είναι καλή επιλογή.

Πώς να εγκαταστήσετε το JAX στο περιβάλλον Python σας

Εάν έχετε εγκατάσταση python στο σύστημά σας και θέλετε να εκτελέσετε το JAX στον τοπικό σας υπολογιστή (CPU), χρησιμοποιήστε τις ακόλουθες εντολές:

pip install --upgrade pip
pip install --upgrade "jax[cpu]"

Εάν θέλετε να εκτελέσετε το Google JAX σε GPU ή TPU, ακολουθήστε τις οδηγίες που δίνονται GitHub JAX σελίδα. Για να ρυθμίσετε την Python, επισκεφτείτε το επίσημες λήψεις python σελίδα.

συμπέρασμα

Το Google JAX είναι εξαιρετικό για τη σύνταξη αποτελεσματικών αλγορίθμων βαθιάς μάθησης, ρομποτικής και έρευνας. Παρά τους περιορισμούς, χρησιμοποιείται εκτενώς με άλλα πλαίσια όπως το Haiku, το Flax και πολλά άλλα. Θα μπορείτε να εκτιμήσετε τι κάνει το JAX όταν εκτελείτε προγράμματα και να δείτε τις διαφορές ώρας στην εκτέλεση κώδικα με και χωρίς JAX. Μπορείτε να ξεκινήσετε διαβάζοντας το επίσημη τεκμηρίωση του Google JAXτο οποίο είναι αρκετά περιεκτικό.