[JAX] JAX: High-Performance Array Computing

JAsmine_log · August 3, 2024

JAX

Concept

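JAX is a machine learning framework for working with numerical functions in Python. It can be viewed as a combination of a modified version of Autograd and TensorFlow's XLA. In other words, this Python library accelerates array computation and program transformation, and is designed for high-performance numerical computing and large-scale ML.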

A machine learning framework for transforming numerical functions, to be used in Python. It is described as bringing together a modified version of Autograd and TensorFlow's XLA. (source: Wikipedia)

A Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning. (source: https://jax.readthedocs.io/en/latest/)
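To make "program transformation" concrete, here is a minimal sketch: a transformation takes an ordinary Python function and returns a new function. The function name square is a hypothetical example and does not come from the sources above.

```python
import jax

# Hypothetical toy function used only for illustration.
def square(x):
    return x ** 2

# grad returns a new function that computes d(square)/dx.
d_square = jax.grad(square)

# jit returns a compiled version of the same function (compiled via XLA).
fast_square = jax.jit(square)

print(d_square(3.0))     # derivative 2x evaluated at x = 3
print(fast_square(3.0))  # same value as square(3.0)
```

Both results are JAX arrays, so they can be passed straight into further jax.numpy operations.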

Installation

  • for CPU on Linux, Windows, macOS
    pip install jax

  • for NVIDIA GPU
    pip install -U "jax[cuda12]"
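After installing, a quick sanity check along these lines can confirm which backend JAX picked up. This is only a sketch; the devices listed depend on the machine.

```python
import jax
import jax.numpy as jnp

# Report the installed version and the devices JAX can see.
print(jax.__version__)
print(jax.devices())          # a list of CPU or GPU devices, depending on the install
print(jax.default_backend())  # for example "cpu" or "gpu"

# A tiny computation to confirm that arrays work end to end.
total = jnp.arange(4.0).sum()
print(total)
```

If the GPU install succeeded, the device list should contain GPU entries rather than only CPU ones.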

Example

01 NumPy

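import jax.numpy as jnp

# Scaled exponential linear unit (SELU), written with NumPy-style jnp calls.
def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(5.0)  # [0., 1., 2., 3., 4.]
print(selu(x))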
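Because jax.numpy mirrors the NumPy API, the same function can be written against either library. The sketch below compares the two versions; the names selu_np and selu_jnp and the input range are assumptions made for illustration.

```python
import numpy as np
import jax.numpy as jnp

# NumPy version (hypothetical name) of the SELU function.
def selu_np(x, alpha=1.67, lmbda=1.05):
    return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)

# jax.numpy version: identical body, only the module prefix changes.
def selu_jnp(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

# Include negative inputs so both branches of where() are exercised.
x_np = np.linspace(-2.0, 2.0, 5)
x_jnp = jnp.asarray(x_np)

out_np = selu_np(x_np)
out_jnp = np.asarray(selu_jnp(x_jnp))  # convert back to a NumPy array

# JAX defaults to 32-bit floats, so compare with a tolerance.
same = np.allclose(out_np, out_jnp, atol=1e-6)
print(same)
```

The tolerance is needed because NumPy computes in float64 here while JAX uses float32 by default.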

02 Taking derivatives with jax.grad()

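import jax.numpy as jnp
from jax import grad

# Sum of the logistic (sigmoid) function over all elements of x.
def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)  # function returning the gradient w.r.t. x
print(derivative_fn(x_small))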
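One way to check the result of grad is to compare it with the closed-form derivative. For the logistic function σ(x), the derivative is σ(x)(1 − σ(x)), so the gradient of the sum is that expression applied element-wise. The helper name sigmoid below is an assumption for illustration.

```python
import jax.numpy as jnp
from jax import grad

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

# Hypothetical helper: the logistic function itself.
def sigmoid(x):
    return 1.0 / (1.0 + jnp.exp(-x))

x_small = jnp.arange(3.)

auto = grad(sum_logistic)(x_small)                # automatic differentiation
manual = sigmoid(x_small) * (1.0 - sigmoid(x_small))  # closed-form derivative

print(auto)
print(manual)
print(jnp.allclose(auto, manual, atol=1e-6))
```

The two arrays should agree up to floating-point rounding.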

03 Composing transformations: jit, grad, vmap

import jax.numpy as jnp
from jax import grad, jit, vmap

def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.tanh(outputs)  # inputs to the next layer
  return outputs                # no activation on last layer

def loss(params, inputs, targets):
  preds = predict(params, inputs)
  return jnp.sum((preds - targets)**2)

grad_loss = jit(grad(loss))  # compiled gradient evaluation function
perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0)))  # fast per-example grads
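The functions above can be exercised with a small random network. This is a sketch under stated assumptions: the layer sizes (3 → 4 → 2), the batch size of 8, and the random data are all hypothetical and chosen only to show the shapes involved.

```python
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap

def predict(params, inputs):
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.tanh(outputs)  # inputs to the next layer
    return outputs                  # no activation on last layer

def loss(params, inputs, targets):
    preds = predict(params, inputs)
    return jnp.sum((preds - targets) ** 2)

grad_loss = jit(grad(loss))
perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0)))

# Hypothetical two-layer network and random data.
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = [
    (jax.random.normal(k1, (3, 4)), jnp.zeros(4)),
    (jax.random.normal(k2, (4, 2)), jnp.zeros(2)),
]
inputs = jax.random.normal(k3, (8, 3))
targets = jax.random.normal(k4, (8, 2))

batch_grads = grad_loss(params, inputs, targets)      # same structure as params
example_grads = perex_grads(params, inputs, targets)  # extra leading batch axis

print(batch_grads[0][0].shape)    # (3, 4)
print(example_grads[0][0].shape)  # (8, 3, 4)
```

Since the loss is a sum over examples, summing the per-example gradients over the leading axis should reproduce the batch gradient up to rounding. Here in_axes=(None, 0, 0) tells vmap to share params across the batch while mapping over the first axis of inputs and targets.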

References
[1] https://jax.readthedocs.io/en/latest/
[2] https://en.wikipedia.org/wiki/Google_JAX
[3] https://github.com/google/jax
