『JAX中文文档』JAX快速入门

最新的

JAX快速入门


首先解答一个问题:JAX是什么?

简单的说就是GPU加速、支持自动微分(autodiff)的numpy。众所周知,numpy是Python下的基础数值运算库,得到广泛应用。用Python搞科学计算或机器学习,没人离得开它。但是numpy不支持GPU或其他硬件加速器,也没有对backpropagation的内置支持,再加上Python本身的速度限制,所以很少有人会在生产环境下直接用numpy训练或部署深度学习模型。这也是为什么会出现Theano, TensorFlow, Caffe等深度学习框架的原因。但是numpy有其独特的优势:底层、灵活、调试方便、API稳定且为大家所熟悉(与MATLAB一脉相承),深受研究者的青睐。JAX的主要出发点就是将numpy的以上优势与硬件加速结合。现在已经开源的JAX ( https://github.com/google/jax) 就是通过GPU (CUDA)来实现硬件加速。出自:https://www.zhihu.com/question/306496943/answer/557876584

 

小宋说:JAX 其实就是一个支持加速器(GPU 和 TPU)的科学计算库(numpy, scipy)和神经网络库(提供relu,sigmoid, conv 等),相较于PyTorch与TensorFlow更加灵活,通用性更佳。这也是笔者推荐学习和做这个翻译工作的原因,带着大家一起去学习掌握这个框架。

由于笔者非英语专业,有些内荣难免翻译有误,欢迎大家批评指正。对于有些笔者不确定的翻译,采用下划线加括号引用原词的方式来补充,例如:自动微分differentiation

 

官方定义:JAX是CPU,GPU和TPU上的NumPy,具有出色的自动差分differentiation),可用于高性能机器学习研究。

 

作为更新版本的Autograd,JAX可以自动微分本机Python和NumPy代码。它可以通过Python的大部分功能(包括循环,if,递归和闭包)进行微分,甚至可以采用派生类的派生类。它支持反向模式和正向模式微分,并且两者可以任意顺序组成。

新功能是JAX使用 XLA 在诸如GPU和TPU的加速器上编译和运行您的NumPy代码。默认情况下,编译是在后台进行的,而库调用将得到及时的编译和执行。但是,JAX甚至允许您使用单功能API即时将自己的Python函数编译为XLA优化的内核。编译和自动微分可以任意组合,因此您无需离开Python即可表达复杂的算法并获得最佳性能。

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

乘法矩阵

 

在以下示例中,我们将生成随机数据。NumPy和JAX之间的一大区别是生成随机数的方式。有关更多详细信息,请参见JAX中的Common Gotchas

key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)
[-0.372111    0.2642311  -0.18252774 -0.7368198  -0.44030386 -0.15214427
 -0.6713536  -0.59086424  0.73168874  0.56730247]

乘以两个大矩阵。

size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU
489 ms ± 3.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

我们补充说,block_until_ready因为默认情况下JAX使用异步执行(请参见异步调度)。

JAX NumPy函数可在常规NumPy数组上使用。

import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()
488 ms ± 942 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

这样比较慢,因为它每次都必须将数据传输到GPU。您可以使用来确保NDArray由设备内存支持device_put()

from jax import device_put

x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()
487 ms ± 9.94 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

的输出device_put()仍然像NDArray一样,但是它仅在需要打印,绘图,保存printing, plotting, saving)到磁盘,分支等需要它们的值时才将值复制回CPU。的行为device_put()等效于函数,但是速度更快。jit(lambda x: x)

如果您有GPU(或TPU!),这些调用将在加速器上运行,并且可能比在CPU上快得多。

x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit np.dot(x, x.T)
235 ms ± 546 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

JAX不仅仅是一个由GPU支持的NumPy。它还带有一些程序转换,这些转换在编写数字代码时很有用。目前,主要有三个:

  • jit(),以加快您的代码

  • grad(),用于求梯度derivatives

  • vmap(),用于自动矢量化或批处理。

让我们一一介绍。我们还将最终以有趣的方式编写这些内容。

 

利用jit()加快功能

JAX在GPU上透明运行(如果没有,则在CPU上运行,而TPU即将推出!)。但是,在上面的示例中,JAX一次将内核分配给GPU一次操作。如果我们有一系列操作,则可以使用@jit装饰器使用XLA一起编译多个操作。让我们尝试一下。

def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()
4.4 ms ± 107 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

我们可以使用加快速度@jit,它将在第一次selu调用jit-compile并将其之后缓存。

selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()
860 µs ± 27.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

通过 grad()计算梯度

除了评估数值函数外,我们还希望对其进行转换。一种转变是自动微分。在JAX中,就像在Autograd中一样,您可以使用grad()函数来计算梯度。

def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
[0.25       0.19661197 0.10499357]

让我们以极限微分finite differences)验证我们的结果是正确的。

def first_finite_differences(f, x):
  eps = 1e-3
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                   for v in jnp.eye(len(x))])


print(first_finite_differences(sum_logistic, x_small))
[0.24998187 0.1964569  0.10502338]

求解梯度可以通过简单调用grad()grad()jit()可以任意混合。在上面的示例中,我们先抖动sum_logistic然后取其派生词。我们继续深入学习实验:

print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
-0.035325594

对于更高级的autodiff,可以将其jax.vjp()用于反向模式矢量雅各比积和jax.jvp()正向模式雅可比矢量积。两者可以彼此任意组合,也可以与其他JAX转换任意组合。这是组合它们以构成有效计算完整的Hessian矩阵的函数的一种方法:

from jax import jacfwd, jacrev
def hessian(fun):
  return jit(jacfwd(jacrev(fun)))

自动向量化 vmap()

JAX在其API中还有另一种转换,您可能会发现它有用:vmap()向量化映射。它具有沿数组轴映射函数的熟悉语义 familiar semantics),但不是将循环保留在外部,而是将循环推入函数的原始操作中以提高性能。当与组合时jit(),它的速度可以与手动添加批处理尺寸一样快。

我们将使用一个简单的示例,并使用将矩阵向量乘积提升为矩阵矩阵乘积vmap()。尽管在这种特定情况下很容易手动完成此操作,但是相同的技术可以应用于更复杂的功能。

mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))

def apply_matrix(v):
  return jnp.dot(mat, v)

给定诸如之类的功能apply_matrix,我们可以在Python中循环执行批处理维度,但是这样做的性能通常很差。

def naively_batched_apply_matrix(v_batched):
  return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
Naively batched
4.43 ms ± 9.91 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

我们知道如何手动批处理此操作。在这种情况下,jnp.dot透明地处理额外的批次尺寸。

@jit
def batched_apply_matrix(v_batched):
  return jnp.dot(v_batched, mat.T)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()
Manually batched
51.9 µs ± 1.72 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

但是,假设没有批处理支持,我们的功能更加复杂。我们可以用来vmap()自动添加批处理支持。

@jit
def vmap_batched_apply_matrix(v_batched):
  return vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
Auto-vectorized with vmap
79.7 µs ± 249 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

当然,vmap()可以与任意组成jit()grad()和任何其它JAX变换。

这只是JAX可以做的事情。我们很高兴看到您的操作!

小宋是呢 CSDN认证博客专家 AI工程师 深度学习领域专家
作者简介:深度学习开发分享博主。全网粉丝3W+,阅读量200W+。
CSDN深度学习博客专家以及微信公众号《简明AI》主要作者。创作内容是基于深度学习的理论学习与应用开发技术分享,致力于最简单明了AI技术分享与最实用AI应用教程。

撰写并发表深度学习论文两篇,获得国家级及省级一等奖奖项八次,以第一作者授权实用新型及发明专利共计十余项,天池与BDCI比赛Top10奖项数次。

在某公司担任算法工程师,从事计算机视觉及时序序列数据的检测识别;深度学习工程化经验丰富,擅长针对新算法研究与应用,包括对模型调优、模型转化及多平台部署等。
Jetty 欢迎访问Jetty文档 Wiki. Jetty是一个开源项目,提供了http服务器、http客户端和java servlet容器。 这个wiki提供jetty的入门教程、基础配置、功能特性、优化、安全、JavaEE、监控、常见问题、故障排除帮助等等。它包含教程、使用手册、视频、特征描述、参考资料以及常见问题。 Jetty文档 ---------------- 入门: 下载Download, 安装, 配置, 运行 Jetty入门(视频) 下载和安装Jetty 如何安装一个Jetty包 如何配置Jetty – 主要文档 如何运行Jetty 用JConsole监控Jetty 如何使用Jetty开发 Jetty HelloWorld教程 Jetty和Maven HelloWorld教程 Jetty(6)入门 (www.itjungle.com) Jetty Start.jar 配置Jetty 如何设置上下文(Context Path) 如何知道使用了那些jar包 如何配置SSL 如何使用非root用户监听80端口 如何配置连接器(Connectors) 如何配置虚拟主机(Virtual Hosts) 如何配置会话ID(Session IDs) 如何序列化会话(Session) 如何重定向或移动应用(Context) 如何让一个应用响应一个特定端口 使用JNDI 使用JNDI 在JNDI中配置数据源(DataSource) 内嵌Jetty服务器 内嵌Jetty教程 内嵌Jetty的HelloWorld教程 内嵌Jetty视频 优化Jetty 如何配置垃圾收集 如何配置以支持高负载 在Jetty中部署应用 部署管理器 部署绑定 热部署 Context提供者 如何部署web应用 webApp提供者 如何部署第三方产品 部署展开形式的web应用 使用Jetty进行开发 如何使用Jetty进行开发 如何编写Jetty中的Handlers 使用构建工具 如何在Maven中使用Jetty 如何在Ant中使用Jetty Maven和Ant的更多支持 Jetty Maven插件(Plugin) Jetty Jspc Maven插件(Plugin) Maven web应用工程原型 Ant Jetty插件(Plugin) 使用集成开发环境(IDEs) 在Eclipse中使用Jetty 在IntelliJ中使用Jetty 在Eclipse中工作 在Eclipse中开发Jetty Jetty WTP插件(Plugin) JettyOSGi SDK for Eclipse-PDE EclipseRT Jetty StarterKit SDK OSGi Jetty on OSGi, RFC66 基于Jetty OSGi的产品 OSGi贴士 Equinox中使用Jetty实现HTTP Service Felix中使用Jetty实现HTTP Service PAX中使用Jetty实现HTTP Srevice ProSyst mBedded Server Equinox Edition Spring Dynamic Modules里的Jetty JOnAS5里的Jetty 配置Ajax、Comet和异步Servlets 持续和异步Servlets 100 Continue和102 Processing WebSocket Servlet 异步的REST Stress Testing CometD 使用Servlets和Filters Jetty中绑定的Servlets Quality of Service Filter Cross Origin Filter 配置安全策略(Security Policies) 安全领域(Security Realms) 安全域配置教程 Java Authentication and Authorization Service (JAAS) JAAS配置教程 JASPI 安全模型(Secure Mode) 存储在文件中的安全密码以及编程教程 如何开启或禁止Jetty中的SSL功能 如何在Jetty中安全存储密码 如何安全终止Jetty 如何配置Spnego Application Server Integrations(集成) Apache Geronimo JEE 配置Apache httpd和Jetty教程 配置Apache mod_proxy和Jetty 配置Jetty中的AJP13 在JBoss中配置Jetty Remote Glassfish EJBs from Jetty Jetty and Spring EJB3 (Pitchfork) JBoss EJB3 ObjectWeb EasyBeans
相关推荐
spring boot中文文档,从安装到部署。 I. Spring Boot文件 1.关于文档 2.获得帮助 3.第一步 4.使用Spring Boot 5.了解Spring Boot功能 6.转向生产 7.高级主题 II。入门 8.介绍Spring Boot 9.系统要求 9.1.Servlet容器 10.安装Spring Boot 10.1.Java Developer的安装说明 10.1.1.Maven安装 10.1.2.Gradle安装 10.2.安装Spring Boot CLI 10.2.1.手动安装 10.2.2.使用SDKMAN安装! 10.2.3.OSX Homebrew安装 10.2.4.MacPorts安装 10.2.5.命令行完成 10.2.6.Windows Scoop安装 10.2.7.快速启动Spring CLI示例 10.3.从早期版本的Spring Boot升级 11.开发您的第一个Spring Boot应用程序 11.1.创建POM 11.2.添加Classpath依赖项 11.3.编写代码 11.3.1.@RestController和@RequestMapping Annotations 11.3.2.@EnableAutoConfiguration注释 11.3.3.“主要”方法 11.4.运行示例 11.5.创建一个可执行的Jar 12.接下来要阅读的内容 III。使用Spring Boot 13.构建系统 13.1.依赖管理 13.2.Maven 13.2.1.继承Starter Parent 13.2.2.在没有父POM的情况下使用Spring Boot 13.2.3.使用Spring Boot Maven插件 13.3.Gradle 13.4.Ant 13.5.Starters 14.构建您的代码 14.1.使用“默认”包 14.2.找到主应用程序类 15.配置类 15.1.导入其他配置类 15.2.导入XML配置 16.自动配置 16.1.逐步更换自动配置 16.2.禁用特定的自动配置类 17. Spring Beans和依赖注入 18.使用@SpringBootApplication Annotation 19.运行您的应用程序 19.1.从IDE运行 19.2.作为打包应用程序运行 19.3.使用Maven插件 19.4.使用Gradle插件 19.5.热插拔 20.开发人员工具 20.1.Property默认值 20.2.自动重启 20.2.1.记录条件评估中的更改 20.2.2.不包括资源 20.2.3.观看其他路径 20.2.4.禁用重启 20.2.5.使用触发器文件 20.2.6.自定义重新启动类加载器 20.2.7.已知限制 20.3.LiveReload 20.4.全局设置 20.5.远程应用 20.5.1.运行远程客户端应用程序 20.5.2.远程更新 21.包装您的生产应用程序 22.接下来要阅读的内容 IV。Spring Boot功能 23. SpringApplication 23.1.启动失败 23.2.自定义横幅 23.3.自定义SpringApplication 23.4.Fluent Builder API 23.5.应用程序事件和监听器 23.6.网络环境 23.7.访问应用程序参数 23.8.使用ApplicationRunner或CommandLineRunner 23.9.申请退出 23.10.管理功能 24.外部配置 24.1.配置随机值 24.2.访问命令行属性 24.3.应用程序Property文件 24.4.配置文件特定的属性 24.5.属性中的占位符 24.6.加密属性 24.7.使用YAML而不是属性 24.7.1.加载YAML 24.7.2.在Spring环境中将YAML公开为属性 24.7.3.多个档案的YAML文件 24.7.4.YAML缺点 24.8.类型安全的配置属性 24.8.1.第三方配置 24.8.2.轻松绑定 24.8.3.合并复杂类型 24.8.4.属性转换 转换持续时间 转换数据大小 24.8.5.@ConfigurationProperties验证 24.8.6.@ConfigurationProperties vs. @Value 25.简介 25.1.添加活动配置文件 25.2.以编程方式设置配置文件 25.3.配置文件特定的配置文件 26.记录 26.1.日志格式 26.2.控制台输出 26.2.1.彩色编码输出 26.3.文件输出 26.4.日志级别 26.5.日志组 26.6.自定义日志配置 26.7.Logback Extensions 26.7.1.特定于配置文件的配置 26.7.2.环境属性 27. JSON 27.1.Jackson 27.2.GSON 27.3.JSON-B 28.开发Web应用程序 28.1.“Spring Web MVC框架” 28.1.1.Spring MVC自动配置 28.1.2.HttpMessageConverters 28.1.3.自定义JSON序列化程序和反序列化程序 28.1.4.MessageCodesResolver的信息 28.1.5.静态内容 28.1.6.欢迎页面 28.1.7.自定义Favicon 28.1.8.路径匹配和内容协商 28.1.9.ConfigurableWebBindingInitializer 28.1.10.模板引擎 28.1.11.错误处理 自定义错误页面 将错误页面映射到Spring MVC之外 28.1.12.Spring HATEOAS 28.1.13.CORS支持 28.2.“Spring WebFlux框架” 28.2.1.Spring WebFlux自动配置 28.2.2.带有HttpMessageReaders和HttpMessageWriters的HTTP编解码器 28.2.3.静态内容 28.2.4.模板引擎 28.2.5.错误处理 自定义错误页面 28.2.6.网络过滤器 28.3.JAX-RS和Jersey 28.4.嵌入式Servlet容器支持 28.4.1.Servlet,过滤器和监听器 注册Servlet,过滤器和监听器Spring Beans 28.4.2.Servlet上下文初始化 扫描Servlet,过滤器和侦听器 28.4.3.ServletWebServerApplicationContext 28.4.4.自定义嵌入式Servlet容器 程序化定制 直接自定义ConfigurableServletWebServerFactory 28.4.5.JSP限制 28.5.嵌入式Reactive Server支持 28.6.Reactive Server资源配置 29.安全 29.1.MVC安全 29.2.WebFlux安全 29.3.OAuth2 29.3.1.客户 OAuth2共同提供者的客户注册 29.3.2.资源服务器 29.3.3.授权服务器 29.4.执行器安全 29.4.1.跨站点请求伪造保护 30.使用SQL数据库 30.1.配置DataSource 30.1.1.嵌入式数据库支持 30.1.2.连接到生产数据库 30.1.3.连接到JNDI数据源 30.2.使用JdbcTemplate 30.3.JPA和Spring Data JPA 30.3.1.实体类 30.3.2.Spring数据JPA存储库 30.3.3.创建和删除JPA数据库 30.3.4.在View中打开EntityM
©️2020 CSDN 皮肤主题: 博客之星2020 设计师:CY__ 返回首页
实付 69.90元
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值