Jax-logistic-regression

Logistic regression classifier using Google's JAX to support GPU acceleration.

This class is an update of a logistic regression class I used in my intro to machine learning course. The major difference is the handling of the gradient descent operations, which were rewritten using JAX's `grad`, `jit`, and `vmap` functions. The goal of this project is speed: with GPU acceleration, JaxReg ran roughly 29x faster than the original class in my tests. I measured the speedup using Google Colab's free GPU (see 'Time Comparison').
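The snippet below is a minimal sketch of how `grad`, `jit`, and `vmap` can fit together in batch gradient descent for logistic regression. It is not the JaxReg class itself. The function names (`predict_proba`, `gd_step`, `fit`), the learning rate, the iteration count, and the toy data are hypothetical choices made for illustration.

```python
import jax
import jax.numpy as jnp


def predict_proba(w, b, x):
    # Positive-class probability for a single example x (shape: [n_features]).
    return jax.nn.sigmoid(jnp.dot(x, w) + b)


# vmap maps the single-example function over the rows of X;
# w and b are shared across the batch (in_axes=None).
batched_predict = jax.vmap(predict_proba, in_axes=(None, None, 0))


def loss(params, X, y):
    # Mean binary cross-entropy; clipping guards against log(0).
    w, b = params
    p = jnp.clip(batched_predict(w, b, X), 1e-7, 1 - 1e-7)
    return -jnp.mean(y * jnp.log(p) + (1 - y) * jnp.log(1 - p))


@jax.jit
def gd_step(params, X, y, lr):
    # One gradient descent update; grad differentiates the loss w.r.t. params.
    grads = jax.grad(loss)(params, X, y)
    return tuple(p - lr * g for p, g in zip(params, grads))


def fit(X, y, lr=0.1, n_iters=1000):
    # Hypothetical training loop: start from zeros and repeat the jitted step.
    params = (jnp.zeros(X.shape[1]), jnp.array(0.0))
    for _ in range(n_iters):
        params = gd_step(params, X, y, lr)
    return params


# Toy usage: a 1-D dataset that is separable between x = 1 and x = 2.
X = jnp.array([[0.0], [1.0], [2.0], [3.0]])
y = jnp.array([0.0, 0.0, 1.0, 1.0])
params = fit(X, y)
w, b = params
preds = batched_predict(w, b, X) > 0.5
```

Wrapping the whole update in `jit` lets XLA compile it once and reuse the compiled step on every iteration, which avoids per-operation Python overhead; how much this helps in practice depends on data size and hardware.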