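JAX is a machine learning framework for working with mathematical functions in Python. It can be seen as a modified version of Autograd combined with TensorFlow's XLA. In other words, this Python library is designed for high-performance numerical computing and large-scale ML, which it achieves by accelerating array operations and program transformations.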
"A machine learning framework for transforming numerical functions, to be used in Python. It is described as bringing together a modified version of autograd and TensorFlow's XLA." (Wikipedia [2])
"A Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning." (JAX documentation [1])
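The "program transformation" idea can be illustrated with a minimal sketch. The function `f` below is an arbitrary example chosen for illustration and does not come from the sources. `jax.make_jaxpr` prints the intermediate representation (a jaxpr) that JAX traces a Python function into. `jax.jit(f).lower(x).as_text()` prints the program that is handed to XLA for compilation. Finally, `grad`, `vmap`, and `jit` are composed on the same function. The exact printed text differs between JAX versions.

```python
import jax
import jax.numpy as jnp

def f(x):
    # An ordinary Python function written with jax.numpy operations
    return jnp.sin(x) * 2.0 + x

x = jnp.arange(3.0)

# Tracing: JAX records the operations f performs as a jaxpr
# (its intermediate representation) instead of executing them one by one.
jaxpr = jax.make_jaxpr(f)(x)
print(jaxpr)

# Lowering: the traced program is converted into the input XLA compiles
# for the current backend; print only the beginning of the text.
lowered = jax.jit(f).lower(x)
print(lowered.as_text()[:300])

# Transformations compose: per-element derivative via grad,
# vectorized over the array with vmap, compiled with jit.
df = jax.jit(jax.vmap(jax.grad(f)))
print(df(x))  # elementwise 2 * cos(x) + 1
```

Because each transformation takes a function and returns a function, they can be stacked in any order that makes sense for the computation.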
Installation for CPU (Linux, Windows, macOS):
pip install jax
Installation for NVIDIA GPU (CUDA 12):
pip install -U "jax[cuda12]"
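After installing, you can quickly check which backend JAX picked up. This is a minimal sketch. With the plain `pip install jax` wheel, the backend should be `cpu`. With the CUDA wheels and a visible GPU, it should be `gpu`. If no usable GPU is found, JAX typically falls back to the CPU.

```python
import jax

# Installed version and the backend JAX will use by default
print(jax.__version__)
print(jax.default_backend())   # e.g. "cpu" or "gpu"

# All devices visible to JAX on this machine
print(jax.devices())
```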
import jax.numpy as jnp
def selu(x, alpha=1.67, lmbda=1.05):
    # SELU: lmbda * x for x > 0, lmbda * alpha * (exp(x) - 1) otherwise
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
x = jnp.arange(5.0)
print(selu(x))
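Because `jax.numpy` mirrors the NumPy API, `selu` can be checked against a plain NumPy version. The same function can also be compiled with `jax.jit`, which this note imports later. In the sketch below, `selu_np` is a hypothetical reference implementation added only for comparison. `block_until_ready()` is used because JAX dispatches work asynchronously.

```python
import numpy as np
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
    # Same definition as above
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

def selu_np(x, alpha=1.67, lmbda=1.05):
    # Hypothetical NumPy reference implementation for comparison
    return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)

x = jnp.arange(-2.0, 3.0)        # [-2., -1., 0., 1., 2.]

# jax.jit traces selu once per input shape/dtype and compiles it with XLA
selu_jit = jax.jit(selu)

eager = selu(x)
compiled = selu_jit(x).block_until_ready()   # wait for the async result
reference = selu_np(np.asarray(x))

print(np.allclose(eager, compiled))
print(np.allclose(eager, reference))
```

The first call to `selu_jit` includes compilation time. Later calls with the same input shape and dtype reuse the compiled program.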
from jax import grad
def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))
x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
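The gradient returned by `grad` can be verified independently. `sum_logistic` sums sigmoids, so each component of its gradient should equal the sigmoid derivative s(x)(1 - s(x)). The sketch below compares that analytic formula with central finite differences. The `finite_diff` helper and its step size `eps` are illustrative choices, not part of JAX.

```python
import jax
import jax.numpy as jnp

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.0)

# Reverse-mode autodiff result
g = jax.grad(sum_logistic)(x_small)

# Analytic check: d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x))
s = jax.nn.sigmoid(x_small)
analytic = s * (1.0 - s)

def finite_diff(f, x, eps=1e-2):
    # Hypothetical helper: central differences along each coordinate axis
    basis = jnp.eye(x.size)
    return jnp.array([(f(x + eps * e) - f(x - eps * e)) / (2 * eps) for e in basis])

numeric = finite_diff(sum_logistic, x_small)

print(g)
print(jnp.allclose(g, analytic, atol=1e-6))
print(jnp.allclose(g, numeric, atol=1e-3))
```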
import jax.numpy as jnp
from jax import grad, jit, vmap
def predict(params, inputs):
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.tanh(outputs)  # inputs to the next layer
    return outputs                  # no activation on last layer

def loss(params, inputs, targets):
    preds = predict(params, inputs)
    return jnp.sum((preds - targets)**2)
grad_loss = jit(grad(loss)) # compiled gradient evaluation function
perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-example grads
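The functions above are defined but never called. The sketch below shows one way to use them. The layer sizes (4 → 8 → 2), the batch size of 16, the `init_params` helper, and the learning rate are all hypothetical choices made for illustration. The loss is a sum over examples, so summing the per-example gradients from `perex_grads` should reproduce the batch gradient from `grad_loss`. The sketch checks this property.

```python
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap

def predict(params, inputs):
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.tanh(outputs)
    return outputs

def loss(params, inputs, targets):
    preds = predict(params, inputs)
    return jnp.sum((preds - targets) ** 2)

grad_loss = jit(grad(loss))
perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0)))

def init_params(key, sizes):
    # Hypothetical initializer: scaled normal weights, zero biases
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, w_key = jax.random.split(key)
        W = jax.random.normal(w_key, (n_in, n_out)) / jnp.sqrt(n_in)
        b = jnp.zeros(n_out)
        params.append((W, b))
    return params

params = init_params(jax.random.PRNGKey(0), [4, 8, 2])
k_in, k_tgt = jax.random.split(jax.random.PRNGKey(1))
inputs = jax.random.normal(k_in, (16, 4))     # batch of 16 examples
targets = jax.random.normal(k_tgt, (16, 2))

# Batch gradient: same pytree structure as params (list of (W, b) tuples)
batch_grads = grad_loss(params, inputs, targets)

# Per-example gradients: every leaf gets a leading batch axis of size 16
per_example = perex_grads(params, inputs, targets)
print([(gW.shape, gb.shape) for gW, gb in per_example])

# Summing per-example grads over the batch axis recovers the batch gradient
summed = jax.tree_util.tree_map(lambda g: g.sum(axis=0), per_example)
matches = jax.tree_util.tree_map(
    lambda u, v: jnp.allclose(u, v, rtol=1e-4, atol=1e-4), summed, batch_grads)
print(all(jax.tree_util.tree_leaves(matches)))

# One plain SGD step; lr is an arbitrary illustrative value
lr = 1e-3
print("loss before:", loss(params, inputs, targets))
params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, batch_grads)
print("loss after: ", loss(params, inputs, targets))
```

`jax.tree_util.tree_map` is used because `params` is a nested pytree. The same update line therefore works for any number of layers.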
References
[1] https://jax.readthedocs.io/en/latest/
[2] https://en.wikipedia.org/wiki/Google_JAX
[3] https://github.com/google/jax