Artificial Intelligence

Multi-Layer Perceptron in JAX

This project provides a minimal implementation of a Multi-Layer Perceptron (MLP) classifier built from scratch in JAX and trained on the Iris dataset. It demonstrates the core components of a supervised learning workflow, including model definition, forward pass, loss computation, gradient-based optimization, and accuracy evaluation, all using JAX’s functional and composable API.
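The workflow described above can be sketched in a few dozen lines of JAX. This is a minimal illustration under stated assumptions, not the project's actual code: the layer sizes `[4, 16, 3]`, the ReLU activation, the learning rate, the step count, and the helper names `init_params`, `forward`, `loss_fn`, `update`, and `accuracy` are all hypothetical choices made for the example. The dataset is loaded with scikit-learn's `load_iris`, and for brevity accuracy is measured on the training data rather than a held-out split.

```python
import jax
import jax.numpy as jnp
import numpy as np
from sklearn.datasets import load_iris

# Hypothetical hyperparameters chosen for this sketch.
LAYER_SIZES = [4, 16, 3]   # 4 Iris features -> 16 hidden units -> 3 classes
LEARNING_RATE = 0.1
NUM_STEPS = 1000

def init_params(key, sizes):
    """Create a list of (W, b) pairs, one per layer (hypothetical helper)."""
    params = []
    keys = jax.random.split(key, len(sizes) - 1)
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        # He-style scaled normal init (an assumption; other inits also work).
        w = jax.random.normal(k, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)
        b = jnp.zeros(n_out)
        params.append((w, b))
    return params

def forward(params, x):
    """Forward pass: ReLU hidden layers, then a linear layer producing logits."""
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b

def loss_fn(params, x, y):
    """Mean softmax cross-entropy; y holds integer class labels."""
    log_probs = jax.nn.log_softmax(forward(params, x))
    return -jnp.mean(jnp.take_along_axis(log_probs, y[:, None], axis=1))

@jax.jit
def update(params, x, y):
    """One full-batch gradient-descent step over the whole parameter pytree."""
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

def accuracy(params, x, y):
    """Fraction of samples whose argmax logit matches the true label."""
    preds = jnp.argmax(forward(params, x), axis=1)
    return jnp.mean(preds == y)

# Load Iris and standardize each feature to zero mean, unit variance.
iris = load_iris()
X = iris.data.astype(np.float32)
X = (X - X.mean(axis=0)) / X.std(axis=0)
y = iris.target.astype(np.int32)

params = init_params(jax.random.PRNGKey(0), LAYER_SIZES)
initial_loss = float(loss_fn(params, X, y))
for _ in range(NUM_STEPS):
    params = update(params, X, y)
final_loss = float(loss_fn(params, X, y))
final_acc = float(accuracy(params, X, y))
print(f"loss {initial_loss:.3f} -> {final_loss:.3f}, training accuracy {final_acc:.3f}")
```

Because the parameters are a plain list of `(W, b)` tuples, `jax.grad` returns gradients with the same structure, and `jax.tree_util.tree_map` applies the update leaf by leaf, so no optimizer library is required. Splitting the data first (for example with scikit-learn's `train_test_split`) would give a fairer accuracy estimate.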


Technologies used

JAX
NumPy
Matplotlib
scikit-learn
