JAX 中文文档(十三)
原文:jax.readthedocs.io/en/latest/
在 JAX 之上构建
原文:jax.readthedocs.io/en/latest/building_on_jax.html
学习高级 JAX 使用的一种很好的方法是看看其他库如何使用 JAX,它们如何将库集成到其 API 中,它在数学上添加了什么功能,并且如何在其他库中用于计算加速。
以下是 JAX 功能如何用于跨多个领域和软件包定义加速计算的示例。
梯度计算
简单的梯度计算是 JAX 的一个关键特性。在JaxOpt 库中值和 grad 直接用于用户在其源代码中的多个优化算法中。
同样,上面提到的 Dynamax Optax 配对,是过去具有挑战性的梯度使估计方法的一个例子,Optax 的最大似然期望。
在多个设备上单核计算速度加快
在 JAX 中定义的模型然后可以被编译以通过 JIT 编译进行单次计算速度加快。相同的编译码然后可以被发送到 CPU 设备,GPU 或 TPU 设备以获得额外的速度加快,通常不需要额外的更改。 这允许平稳地从开发流程转入生产流程。在 Dynamax 中,线性状态空间模型求解器的计算密集型部分已jitted。 PyTensor 的一个更复杂的例子源于动态地编译 JAX 函数,然后jit 构造的函数。
使用并行化的单台和多台计算机加速
JAX 的另一个好处是使用pmap和vmap函数调用或装饰器轻松并行化计算。在 Dynamax 中,状态空间模型使用VMAP 装饰器进行并行化,其实际用例是多对象跟踪。
将 JAX 代码合并到您的工作流程中或您的用户工作流程中
JAX 非常可组合,并且可以以多种方式使用。 JAX 可以作为独立模式使用,用户自己定义所有计算。 但是其他模式,例如使用构建在 jax 上提供特定功能的库。 这些可以是定义特定类型的模型的库,例如神经网络或状态空间模型或其他,或者提供特定功能,例如优化。以下是每种模式的更具体的示例。
直接使用
Jax 可以直接导入和利用,以便在本网站上“从零开始”构建模型,例如在JAX 教程或使用 JAX 进行神经网络中展示的方法。如果您无法找到特定挑战的预建代码,或者希望减少代码库中的依赖项数量,这可能是最佳选择。
使用 JAX 暴露的可组合领域特定库
另一种常见方法是提供预建功能的包,无论是模型定义还是某种类型的计算。这些包的组合可以混合使用,以实现全面的端到端工作流程,定义模型并估计其参数。
一个例子是Flax,它简化了神经网络的构建。通常将 Flax 与Optax配对使用,其中 Flax 定义了神经网络架构,而 Optax 提供了优化和模型拟合能力。
另一个是Dynamax,它允许轻松定义状态空间模型。使用 Dynamax 可以使用Optax 进行最大似然估计,或者使用Blackjax 进行 MCMC 全贝叶斯后验估计。
用户完全隐藏 JAX
其他库选择完全包装 JAX 以适应其特定 API。例如,PyMC 和Pytensor就是一个例子,用户可能从未直接“看到”JAX,而是使用 PyMC 特定的 API 包装JAX 函数。
注:
原文:jax.readthedocs.io/en/latest/notes.html
本节包含有关使用 JAX 相关主题的简短注释;另请参阅 JAX Enhancement Proposals (JEPs) 中更详细的设计讨论。
依赖和版本兼容性:
API 兼容性概述了 JAX 在不同版本之间 API 兼容性的政策。
Python 和 NumPy 版本支持政策概述了 JAX 与 Python 和 NumPy 的兼容性政策。
迁移和弃用事项:
jax.Array 迁移总结了 jax v 0.4.1 中默认数组类型的更改。
内存和计算使用:
异步调度描述了 JAX 的异步调度模型。
并发性描述了 JAX 与其他 Python 并发性的交互方式。
GPU 内存分配描述了 JAX 在 GPU 内存分配中的交互方式。
程序员保护栏:
等级提升警告描述了如何配置 jax.numpy 以避免隐式等级提升。
API 兼容性
原文:jax.readthedocs.io/en/latest/api_compatibility.html
JAX 不断发展,我们希望能改进其 API。尽管如此,我们希望最大程度减少 JAX 用户社区的混乱,并尽量少做破坏性更改。
JAX 遵循三个月的废弃政策。当对 API 进行不兼容的更改时,我们将尽力遵守以下流程:
更改将在 CHANGELOG.md 中和被废弃 API 的文档字符串中公布,并且旧 API 将发出 DeprecationWarning。
在 jax 发布了废弃 API 后的三个月内,我们可能随时移除已废弃的 API。请注意,三个月是一个较短的时间界限,故意选择快于许多更成熟项目的时间界限。实际上,废弃可能需要更长时间,特别是如果某个功能有很多用户。如果三个月的废弃期变得问题重重,请与我们联系。
我们保留随时更改此政策的权利。
覆盖了什么内容?
仅涵盖公共的 JAX API,包括以下模块:
jax
jax.dlpack
jax.image
jax.lax
jax.nn
jax.numpy
jax.ops
jax.profiler
jax.random (参见下文详细说明)
jax.scipy
jax.tree_util
jax.test_util
这些模块中并非所有内容都是公开的。随着时间的推移,我们正在努力区分公共 API 和私有 API。公共 API 在 JAX 文档中有详细记录。此外,我们的目标是所有非公共 API 应以下划线作为前缀命名,尽管我们目前还未完全遵守这一规定。
未覆盖的内容是什么?
任何以下划线开头的内容。
jax._src
jax.core
jax.linear_util
jax.lib
jax.prng
jax.interpreters
jax.experimental
jax.example_libraries
jax.extend (参见详情)
此列表并非详尽无遗。
数值和随机性
数值运算的确切值在 JAX 的不同版本中并不保证稳定。事实上,在给定的 JAX 版本、加速器平台上,在或不在 jax.jit 内部,等等,确切的数值计算不一定是稳定的。
对于固定 PRNG 密钥输入,jax.random 中伪随机函数的输出可能会在 JAX 不同版本间变化。兼容性政策仅适用于输出的分布。例如,表达式 jax.random.gumbel(jax.random.key(72)) 在 JAX 的不同版本中可能返回不同的值,但 jax.random.gumbel 仍然是 Gumbel 分布的伪随机生成器。
我们尽量不频繁地更改伪随机值。当更改发生时,会在变更日志中公布,但不遵循废弃周期。在某些情况下,JAX 可能会暴露一个临时配置标志,用于回滚新行为,以帮助用户诊断和更新受影响的代码。此类标志将持续一段废弃时间。
Python 和 NumPy 版本支持政策
原文:jax.readthedocs.io/en/latest/deprecation.html
对于 NumPy 和 SciPy 版本支持,JAX 遵循 Python 科学社区的 SPEC 0。
对于 Python 版本支持,我们听取了用户的意见,36 个月的支持窗口可能太短,例如由于新 CPython 版本到 Linux 供应商版本的延迟传播。因此,JAX 支持 Python 版本至少比 SPEC-0 推荐的长九个月。
这意味着我们至少支持:
在每个 JAX 发布前 45 个月内的所有较小的 Python 版本。例如:
Python 3.9于 2020 年 10 月发布,并将至少在2024 年 7 月之前支持新的 JAX 发布。
Python 3.10于 2021 年 10 月发布,并将至少在2025 年 7 月之前支持新的 JAX 发布。
Python 3.11于 2022 年 10 月发布,并将至少在2026 年 7 月之前支持新的 JAX 发布。
在每个 JAX 发布前 24 个月内的所有较小的 NumPy 版本。例如:
NumPy 1.22于 2021 年 12 月发布,并将至少在2023 年 12 月之前支持新的 JAX 发布。
NumPy 1.23于 2022 年 6 月发布,并将至少在2024 年 6 月之前支持新的 JAX 发布。
NumPy 1.24于 2022 年 12 月发布,并将至少在2024 年 12 月之前支持新的 JAX 发布。
在每个 JAX 发布前 24 个月内的所有较小的 SciPy 版本,从 SciPy 版本 1.9 开始。例如:
Scipy 1.9于 2022 年 7 月发布,并将至少在2024 年 7 月之前支持新的 JAX 发布。
Scipy 1.10于 2023 年 1 月发布,并将至少在2025 年 1 月之前支持新的 JAX 发布。
Scipy 1.11于 2023 年 6 月发布,并将至少在2025 年 6 月之前支持新的 JAX 发布。
JAX 发布可以支持比本政策严格要求的更旧的 Python、NumPy 和 SciPy 版本,但对更旧版本的支持可能随时在列出的日期之后终止。
jax.Array 迁移
原文:jax.readthedocs.io/en/latest/jax_array_migration.html
yashkatariya@
TL;DR
JAX 将其默认数组实现切换为新的 jax.Array 自版本 0.4.1 起。本指南解释了这一决定的背景,它可能对您的代码产生的影响,以及如何(临时)切换回旧行为。
发生了什么?
jax.Array 是 JAX 中统一的数组类型,包括 DeviceArray、ShardedDeviceArray 和 GlobalDeviceArray 类型。jax.Array 类型有助于使并行成为 JAX 的核心特性,简化和统一了 JAX 的内部结构,并允许我们统一 jit 和 pjit。如果你的代码没有涉及到 DeviceArray、ShardedDeviceArray 和 GlobalDeviceArray 的区别,那就不需要进行任何更改。但是依赖于这些单独类细节的代码可能需要进行调整以适配统一的 jax.Array。
迁移完成后,jax.Array 将成为 JAX 中唯一的数组类型。
本文介绍了如何将现有代码库迁移到 jax.Array。有关如何使用 jax.Array 和 JAX 并行 API 的更多信息,请参阅 Distributed arrays and automatic parallelization 教程。
如何启用 jax.Array?
你可以通过以下方式启用 jax.Array:
设置 shell 环境变量 JAX_ARRAY 为真值(例如 1);
如果你的代码使用 absl 解析标志,可以将布尔标志 jax_array 设置为真值;
在你的主文件顶部加入以下声明:
import jax
jax.config.update('jax_array', True)
如何判断 jax.Array 是否破坏了我的代码?
最简单的方法是禁用 jax.Array,看看问题是否解决。
我如何暂时禁用 jax.Array?
通过 2023 年 3 月 15 日,可以通过以下方式禁用 jax.Array:
设置 shell 环境变量 JAX_ARRAY 为假值(例如 0);
如果你的代码使用 absl 解析标志,可以将布尔标志 jax_array 设置为假值;
在你的主文件顶部加入以下声明:
import jax
jax.config.update('jax_array', False)
为什么创建 jax.Array?
当前 JAX 有三种类型:DeviceArray、ShardedDeviceArray 和 GlobalDeviceArray。jax.Array 合并了这三种类型,并清理了 JAX 的内部结构,同时增加了新的并行特性。
我们还引入了一个新的 Sharding 抽象,描述了逻辑数组如何在一个或多个设备(如 TPU 或 GPU)上物理分片。这一变更还升级、简化并将 pjit 的并行性特性合并到 jit 中。使用 jit 装饰的函数将能够在分片数组上操作,而无需将数据复制到单个设备上。
使用 jax.Array 可以获得的功能:
C++ pjit 分派路径
逐操作并行性(即使数组分布在多台设备上,跨多个主机)
使用 pjit/jit 更简单的批数据并行性。
可以完全利用 OpSharding 的灵活性,或者任何您想要的其他分片方式来创建不一定包含网格和分区规范的 Sharding。
等等
示例:
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
import numpy as np
x = jnp.arange(8)
# Let's say there are 8 devices in jax.devices()
mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x'))
sharded_x = jax.device_put(x, sharding)
# `matmul_sharded_x` and `sin_sharded_x` are sharded. `jit` is able to operate over a
# sharded array without copying data to a single device.
matmul_sharded_x = sharded_x @ sharded_x.T
sin_sharded_x = jnp.sin(sharded_x)
# Even jnp.copy preserves the sharding on the output.
copy_sharded_x = jnp.copy(sharded_x)
# double_out is also sharded
double_out = jax.jit(lambda x: x * 2)(sharded_x)
切换到 jax.Array 后可能会出现哪些问题?
新公共类型命名为 jax.Array。
所有 isinstance(..., jnp.DeviceArray) 或 isinstance(.., jax.xla.DeviceArray) 以及其他 DeviceArray 的变体应该切换到使用 isinstance(..., jax.Array)。
由于 jax.Array 可以表示 DA、SDA 和 GDA,您可以通过以下方式在 jax.Array 中区分这三种类型:
x.is_fully_addressable and len(x.sharding.device_set) == 1 – 这意味着 jax.Array 类似于 DA。
x.is_fully_addressable and (len(x.sharding.device_set) > 1 – 这意味着 jax.Array 类似于 SDA。
not x.is_fully_addressable – 这意味着 jax.Array 类似于 GDA,并跨多个进程。
对于 ShardedDeviceArray,可以将 isinstance(..., pxla.ShardedDeviceArray) 转移到 isinstance(..., jax.Array) and x.is_fully_addressable and len(x.sharding.device_set) > 1。
通常无法区分单设备数组上的 ShardedDeviceArray 与任何其他类型的单设备数组。
GDA 的 API 名称变更
GDA 的 local_shards 和 local_data 已经被弃用。
请使用与 jax.Array 和 GDA 兼容的 addressable_shards 和 addressable_data。
创建 jax.Array。
当 jax_array 标志为真时,所有 JAX 函数将输出 jax.Array。如果您曾使用 GlobalDeviceArray.from_callback、make_sharded_device_array 或 make_device_array 函数显式创建相应的 JAX 数据类型,则需要切换为使用 jax.make_array_from_callback() 或 jax.make_array_from_single_device_arrays()。
对于 GDA:
GlobalDeviceArray.from_callback(shape, mesh, pspec, callback) 可以一对一地切换为 jax.make_array_from_callback(shape, jax.sharding.NamedSharding(mesh, pspec), callback)。
如果您曾使用原始的 GDA 构造函数来创建 GDAs,则执行以下操作:
GlobalDeviceArray(shape, mesh, pspec, buffers) 可以变成 jax.make_array_from_single_device_arrays(shape, jax.sharding.NamedSharding(mesh, pspec), buffers)。
对于 SDA:
make_sharded_device_array(aval, sharding_spec, device_buffers, indices) 可以变成 jax.make_array_from_single_device_arrays(shape, sharding, device_buffers)。
要决定分片应该是什么,取决于您创建 SDA 的原因:
如果它被创建为 pmap 的输入,则分片可以是:jax.sharding.PmapSharding(devices, sharding_spec)。
如果它被创建为 pjit 的输入,则分片可以是 jax.sharding.NamedSharding(mesh, pspec)。
切换到 jax.Array 后对于主机本地输入的 pjit 有破坏性变更。
如果您完全使用 GDA 参数作为 pjit 的输入,则可以跳过此部分! 🎉
启用jax.Array后,所有传递给pjit的输入必须是全局形状的。这是与之前行为不兼容的变化,之前的pjit会将进程本地的参数连接成一个全局值;现在不再进行此连接。
为什么我们要进行这个突破性的变化?现在每个数组都明确说明了它的本地分片如何适合全局整体,而不是留下隐含的情况。更明确的表示方式还可以解锁额外的灵活性,例如在某些 TPU 模型上可以提高效率的非连续网格使用pjit。
在启用jax.Array时,运行多进程 pjit 计算并在传递主机本地输入时可能会导致类似以下错误:
示例:
Mesh = {'x': 2, 'y': 2, 'z': 2} 和主机本地输入形状 == (4,) 以及pspec = P(('x', 'y', 'z'))
因为pjit不会将主机本地形状提升为全局形状,所以您会收到以下错误:
注意:只有当您的主机本地形状小于网格的形状时,才会看到此错误。
ValueError: One of pjit arguments was given the sharding of
NamedSharding(mesh={'x': 2, 'y': 2, 'chips': 2}, partition_spec=PartitionSpec(('x', 'y', 'chips'),)),
which implies that the global size of its dimension 0 should be divisible by 8,
but it is equal to 4
错误出现是因为当维度0上的值为4时,无法将其分片成 8 份。
如果你仍然将主机本地输入传递给pjit,如何迁移?我们提供了过渡 API 来帮助您迁移:
注意:如果您在单进程上运行pjit计算,则不需要这些实用程序。
from jax.experimental import multihost_utils
global_inps = multihost_utils.host_local_array_to_global_array(
local_inputs, mesh, in_pspecs)
global_outputs = pjit(f, in_shardings=in_pspecs,
out_shardings=out_pspecs)(global_inps)
local_outs = multihost_utils.global_array_to_host_local_array(
global_outputs, mesh, out_pspecs)
host_local_array_to_global_array是一种类型转换,它查看具有仅本地分片的值,并将其本地形状更改为在更改之前如果传递该值pjit会假定的形状。
支持完全复制的输入,即每个进程上具有相同形状,并且in_axis_resources为P(None)的情况。在这种情况下,您无需使用host_local_array_to_global_array,因为形状已经是全局的。
key = jax.random.PRNGKey(1)
# As you can see, using host_local_array_to_global_array is not required since in_axis_resources says
# that the input is fully replicated via P(None)
pjit(f, in_shardings=None, out_shardings=None)(key)
# Mixing inputs
global_inp = multihost_utils.host_local_array_to_global_array(
local_inp, mesh, P('data'))
global_out = pjit(f, in_shardings=(P(None), P('data')),
out_shardings=...)(key, global_inp)
FROM_GDA和jax.Array
如果你在in_axis_resources参数中使用FROM_GDA来传递给pjit,那么在使用jax.Array时,无需向in_axis_resources传递任何内容,因为jax.Array将遵循计算遵循分片的语义。
例如:
pjit(f, in_shardings=FROM_GDA, out_shardings=...) can be replaced by pjit(f, out_shardings=...)
如果你的输入中混合了PartitionSpecs和FROM_GDA,例如 numpy 数组等,则使用host_local_array_to_global_array将它们转换为jax.Array。
例如:
如果你有这样的情况:
pjitted_f = pjit(
f, in_shardings=(FROM_GDA, P('x'), FROM_GDA, P(None)),
out_shardings=...)
pjitted_f(gda1, np_array1, gda2, np_array2)
然后您可以将其替换为:
pjitted_f = pjit(f, out_shardings=...)
array2, array3 = multihost_utils.host_local_array_to_global_array(
(np_array1, np_array2), mesh, (P('x'), P(None)))
pjitted_f(array1, array2, array3, array4)
live_buffers替换为live_arrays。
jax Device上的live_buffers属性已被弃用。请改用与jax.Array兼容的jax.live_arrays()。
处理向pjit传递的主机本地输入,例如批次等。
如果在多进程环境中向pjit传递主机本地输入,请使用multihost_utils.host_local_array_to_global_array将批次转换为全局jax.Array,然后将其传递给pjit。
这种主机本地输入最常见的例子是输入数据批次。
这对任何主机本地输入都有效(不仅仅是输入数据批次)。
from jax.experimental import multihost_utils
batch = multihost_utils.host_local_array_to_global_array(
batch, mesh, batch_partition_spec)
关于这种变化以及更多示例,请参阅上面的 pjit 部分。
RecursionError:递归调用 jit 时发生的错误。
当你的代码的某部分禁用了 jax.Array,然后你仅在其他部分启用它时会出现这种情况。例如,如果你使用某些第三方代码,该代码已禁用了 jax.Array 并从该库获得一个 DeviceArray,然后在你的库中启用 jax.Array 并将该 DeviceArray 传递给 JAX 函数,就会导致 RecursionError。
当 jax.Array 默认启用时,所有库都返回 jax.Array,除非显式禁用它,这个错误就应该消失。
异步调度
原文:jax.readthedocs.io/en/latest/async_dispatch.html
JAX 使用异步调度来隐藏 Python 的开销。考虑以下程序:
>>> import numpy as np
>>> import jax.numpy as jnp
>>> from jax import random
>>> x = random.uniform(random.key(0), (1000, 1000))
>>> # Printing the result (i.e. evaluating `repr(result)` or `str(result)`)
>>> # will block until the value is ready.
>>> jnp.dot(x, x) + 3.
Array([[258.01971436, 249.64862061, 257.13372803, ...,
236.67948914, 250.68939209, 241.36853027],
[265.65979004, 256.28912354, 262.18252563, ...,
242.03181458, 256.16757202, 252.44122314],
[262.38916016, 255.72747803, 261.23059082, ...,
240.83563232, 255.41094971, 249.62471008],
...,
[259.15814209, 253.09197998, 257.72174072, ...,
242.23876953, 250.72680664, 247.16642761],
[271.22662354, 261.91204834, 265.33398438, ...,
248.26651001, 262.05389404, 261.33700562],
[257.16134644, 254.7543335, 259.08300781, ..., 241.59848022,
248.62597656, 243.22348022]], dtype=float32)
当执行诸如 jnp.dot(x, x) 这样的操作时,JAX 不会等待操作完成再将控制返回给 Python 程序。相反,JAX 返回一个 jax.Array 值,它是一个未来的值,即将来在加速设备上生成但不一定立即可用的值。我们可以检查 jax.Array 的形状或类型,而无需等待生成它的计算完成,并且甚至可以将其传递给另一个 JAX 计算,正如我们在此处执行加法操作一样。只有当我们实际从主机检查数组的值时,例如通过打印它或将其转换为普通的 numpy.ndarray,JAX 才会强制 Python 代码等待计算完成。
异步调度非常有用,因为它允许 Python 代码在加速设备之前“超前运行”,从而避免 Python 代码进入关键路径。只要 Python 代码将工作快速地加入设备的队列,比它执行得更快,并且只要 Python 代码实际上不需要检查主机上的计算输出,那么 Python 程序就可以加入任意量的工作并避免让加速器等待。
异步调度对微基准测试有一个稍显意外的影响。
>>> %time jnp.dot(x, x)
CPU times: user 267 µs, sys: 93 µs, total: 360 µs
Wall time: 269 µs
Array([[255.01972961, 246.64862061, 254.13371277, ...,
233.67948914, 247.68939209, 238.36853027],
[262.65979004, 253.28910828, 259.18252563, ...,
239.03181458, 253.16757202, 249.44122314],
[259.38916016, 252.72747803, 258.23059082, ...,
237.83563232, 252.41094971, 246.62471008],
...,
[256.15814209, 250.09197998, 254.72172546, ...,
239.23876953, 247.72680664, 244.16642761],
[268.22662354, 258.91204834, 262.33398438, ...,
245.26651001, 259.05389404, 258.33700562],
[254.16134644, 251.7543335, 256.08300781, ..., 238.59848022,
245.62597656, 240.22348022]], dtype=float32)
对于在 CPU 上进行的 1000x1000 矩阵乘法来说,269µs 的时间是一个令人惊讶地小的时间!然而,事实证明异步调度在误导我们,我们并没有计时矩阵乘法的执行,而是调度工作的时间。要测量操作的真正成本,我们必须要么在主机上读取值(例如,将其转换为普通的主机端 numpy 数组),要么在 jax.Array 值上使用 block_until_ready() 方法,等待生成它的计算完成。
>>> %time np.asarray(jnp.dot(x, x))
CPU times: user 61.1 ms, sys: 0 ns, total: 61.1 ms
Wall time: 8.09 ms
Out[16]:
array([[255.01973, 246.64862, 254.13371, ..., 233.67949, 247.68939,
238.36853],
[262.6598 , 253.28911, 259.18253, ..., 239.03181, 253.16757,
249.44122],
[259.38916, 252.72748, 258.2306 , ..., 237.83563, 252.41095,
246.62471],
...,
[256.15814, 250.09198, 254.72173, ..., 239.23877, 247.7268 ,
244.16643],
[268.22662, 258.91205, 262.33398, ..., 245.26651, 259.0539 ,
258.337 ],
[254.16135, 251.75433, 256.083 , ..., 238.59848, 245.62598,
240.22348]], dtype=float32)
>>> %time jnp.dot(x, x).block_until_ready()
CPU times: user 50.3 ms, sys: 928 µs, total: 51.2 ms
Wall time: 4.92 ms
Array([[255.01972961, 246.64862061, 254.13371277, ...,
233.67948914, 247.68939209, 238.36853027],
[262.65979004, 253.28910828, 259.18252563, ...,
239.03181458, 253.16757202, 249.44122314],
[259.38916016, 252.72747803, 258.23059082, ...,
237.83563232, 252.41094971, 246.62471008],
...,
[256.15814209, 250.09197998, 254.72172546, ...,
239.23876953, 247.72680664, 244.16642761],
[268.22662354, 258.91204834, 262.33398438, ...,
245.26651001, 259.05389404, 258.33700562],
[254.16134644, 251.7543335, 256.08300781, ..., 238.59848022,
245.62597656, 240.22348022]], dtype=float32)
在不将结果转移到 Python 的情况下进行阻塞通常更快,通常是编写计算时间微基准测试时的最佳选择。
并发
JAX 并发
JAX 对 Python 并发的支持有限。
客户端可以从不同的 Python 线程并发调用 JAX API(例如,jit() 或 grad())。
不允许同时从多个线程并发地操作 JAX 追踪值。换句话说,虽然可以从多个线程调用使用 JAX 追踪的函数(例如 jit()),但不得使用线程来操作传递给 jit() 的函数 f 实现内部的 JAX 值。如果这样做,最有可能的结果是 JAX 报告一个神秘的错误。
GPU 内存分配
原文:jax.readthedocs.io/en/latest/gpu_memory_allocation.html
当第一个 JAX 操作运行时,JAX 将预先分配总 GPU 内存的 75%。 预先分配可以最小化分配开销和内存碎片化,但有时会导致内存不足(OOM)错误。如果您的 JAX 进程因内存不足而失败,可以使用以下环境变量来覆盖默认行为:
XLA_PYTHON_CLIENT_PREALLOCATE=false
这将禁用预分配行为。JAX 将根据需要分配 GPU 内存,可能会减少总体内存使用。但是,这种行为更容易导致 GPU 内存碎片化,这意味着使用大部分可用 GPU 内存的 JAX 程序可能会在禁用预分配时发生 OOM。
XLA_PYTHON_CLIENT_MEM_FRACTION=.XX
如果启用了预分配,这将使 JAX 预分配总 GPU 内存的 XX% ,而不是默认的 75%。减少预分配量可以修复 JAX 程序启动时的内存不足问题。
XLA_PYTHON_CLIENT_ALLOCATOR=platform
这使得 JAX 根据需求精确分配内存,并释放不再需要的内存(请注意,这是唯一会释放 GPU 内存而不是重用它的配置)。这样做非常慢,因此不建议用于一般用途,但可能对于以最小可能的 GPU 内存占用运行或调试 OOM 失败非常有用。
OOM 失败的常见原因
同时运行多个 JAX 进程。
要么使用 XLA_PYTHON_CLIENT_MEM_FRACTION 为每个进程分配适当的内存量,要么设置 XLA_PYTHON_CLIENT_PREALLOCATE=false。
同时运行 JAX 和 GPU TensorFlow。
TensorFlow 默认也会预分配,因此这与同时运行多个 JAX 进程类似。
一个解决方案是仅使用 CPU TensorFlow(例如,如果您仅使用 TF 进行数据加载)。您可以使用命令 tf.config.experimental.set_visible_devices([], "GPU") 阻止 TensorFlow 使用 GPU。
或者,使用 XLA_PYTHON_CLIENT_MEM_FRACTION 或 XLA_PYTHON_CLIENT_PREALLOCATE。还有类似的选项可以配置 TensorFlow 的 GPU 内存分配(gpu_memory_fraction 和 allow_growth 在 TF1 中应该设置在传递给 tf.Session 的 tf.ConfigProto 中。参见 使用 GPU:限制 GPU 内存增长 用于 TF2)。
在显示 GPU 上运行 JAX。
使用 XLA_PYTHON_CLIENT_MEM_FRACTION 或 XLA_PYTHON_CLIENT_PREALLOCATE。
提升秩警告
原文:jax.readthedocs.io/en/latest/rank_promotion_warning.html
NumPy 广播规则 允许自动将参数从一个秩(数组轴的数量)提升到另一个秩。当意图明确时,此行为很方便,但也可能导致意外的错误,其中静默的秩提升掩盖了潜在的形状错误。
下面是提升秩的示例:
>>> import numpy as np
>>> x = np.arange(12).reshape(4, 3)
>>> y = np.array([0, 1, 0])
>>> x + y
array([[ 0, 2, 2],
[ 3, 5, 5],
[ 6, 8, 8],
[ 9, 11, 11]])
为了避免潜在的意外,jax.numpy 可配置,以便需要提升秩的表达式会导致警告、错误或像常规 NumPy 一样允许。配置选项名为 jax_numpy_rank_promotion,可以取字符串值 allow、warn 和 raise。默认设置为 allow,允许提升秩而不警告或错误。设置为 raise 则在提升秩时引发错误,而 warn 在首次提升秩时引发警告。
可以使用 jax.numpy_rank_promotion() 上下文管理器在本地启用或禁用提升秩:
with jax.numpy_rank_promotion("warn"):
z = x + y
这个配置也可以在多种全局方式下设置。其中一种是在代码中使用 jax.config:
import jax
jax.config.update("jax_numpy_rank_promotion", "warn")
也可以使用环境变量 JAX_NUMPY_RANK_PROMOTION 来设置选项,例如 JAX_NUMPY_RANK_PROMOTION='warn'。最后,在使用 absl-py 时,可以使用命令行标志设置选项。
公共 API:jax 包
原文:jax.readthedocs.io/en/latest/jax.html
子包
jax.numpy 模块
jax.scipy 模块
jax.lax 模块
jax.random 模块
jax.sharding 模块
jax.debug 模块
jax.dlpack 模块
jax.distributed 模块
jax.dtypes 模块
jax.flatten_util 模块
jax.image 模块
jax.nn 模块
jax.ops 模块
jax.profiler 模块
jax.stages 模块
jax.tree 模块
jax.tree_util 模块
jax.typing 模块
jax.export 模块
jax.extend 模块
jax.example_libraries 模块
jax.experimental 模块
配置
config
check_tracer_leaks
jax_check_tracer_leaks 配置选项的上下文管理器。
checking_leaks
jax_check_tracer_leaks 配置选项的上下文管理器。
debug_nans
jax_debug_nans 配置选项的上下文管理器。
debug_infs
jax_debug_infs 配置选项的上下文管理器。
default_device
jax_default_device 配置选项的上下文管理器。
default_matmul_precision
jax_default_matmul_precision 配置选项的上下文管理器。
default_prng_impl
jax_default_prng_impl 配置选项的上下文管理器。
enable_checks
jax_enable_checks 配置选项的上下文管理器。
enable_custom_prng
jax_enable_custom_prng 配置选项的上下文管理器(临时)。
enable_custom_vjp_by_custom_transpose
jax_enable_custom_vjp_by_custom_transpose 配置选项的上下文管理器(临时)。
log_compiles
jax_log_compiles 配置选项的上下文管理器。
numpy_rank_promotion
jax_numpy_rank_promotion 配置选项的上下文管理器。
transfer_guard(new_val)
控制所有传输的传输保护级别的上下文管理器。
即时编译 (jit)
jit(fun[, in_shardings, out_shardings, ...])
使用 XLA 设置 fun 进行即时编译。
disable_jit([disable])
禁用其动态上下文下 jit() 行为的上下文管理器。
ensure_compile_time_eval()
确保在追踪/编译时进行评估的上下文管理器(或错误)。
xla_computation(fun[, static_argnums, ...])
创建一个函数,给定示例参数,产生其 XLA 计算。
make_jaxpr([axis_env, return_shape, ...])
创建一个函数,给定示例参数,产生其 jaxpr。
eval_shape(fun, *args, **kwargs)
计算 fun 的形状/数据类型,不进行任何 FLOP 计算。
ShapeDtypeStruct(shape, dtype[, ...])
数组的形状、dtype 和其他静态属性的容器。
device_put(x[, device, src])
将 x 传输到 device。
device_put_replicated(x, devices)
将数组传输到每个指定的设备并形成数组。
device_put_sharded(shards, devices)
将数组片段传输到指定设备并形成数组。
device_get(x)
将 x 传输到主机。
default_backend()
返回默认 XLA 后端的平台名称。
named_call(fun, *[, name])
在 JAX 计算中给函数添加用户指定的名称。
named_scope(name)
将用户指定的名称添加到 JAX 名称堆栈的上下文管理器。
| block_until_ready(x) | 尝试调用 pytree 叶子上的 block_until_ready 方法。 | ## 自动微分
grad(fun[, argnums, has_aux, holomorphic, ...])
创建一个评估 fun 梯度的函数。
value_and_grad(fun[, argnums, has_aux, ...])
创建一个同时评估 fun 和 fun 梯度的函数。
jacfwd(fun[, argnums, has_aux, holomorphic])
使用正向模式自动微分逐列计算 fun 的雅可比矩阵。
jacrev(fun[, argnums, has_aux, holomorphic, ...])
使用反向模式自动微分逐行计算 fun 的雅可比矩阵。
hessian(fun[, argnums, has_aux, holomorphic])
fun 的 Hessian 矩阵作为稠密数组。
jvp(fun, primals, tangents[, has_aux])
计算 fun 的(正向模式)雅可比向量乘积。
linearize()
使用 jvp() 和部分求值生成对 fun 的线性近似。
linear_transpose(fun, *primals[, reduce_axes])
转置一个承诺为线性的函数。
vjp() ))
计算 fun 的(反向模式)向量-Jacobian 乘积。
custom_jvp(fun[, nondiff_argnums])
为自定义 JVP 规则定义一个可 JAX 化的函数。
custom_vjp(fun[, nondiff_argnums])
为自定义 VJP 规则定义一个可 JAX 化的函数。
custom_gradient(fun)
方便地定义自定义的 VJP 规则(即自定义梯度)。
closure_convert(fun, *example_args)
闭包转换实用程序,用于与高阶自定义导数一起使用。
checkpoint(fun, *[, prevent_cse, policy, ...])
使 fun 在求导时重新计算内部线性化点。
jax.Array (jax.Array)
Array()
JAX 的数组基类
make_array_from_callback(shape, sharding, ...)
通过从 data_callback 获取的数据返回一个 jax.Array。
make_array_from_single_device_arrays(shape, ...)
从每个位于单个设备上的 jax.Array 序列返回一个 jax.Array。
make_array_from_process_local_data(sharding, ...)
使用进程中可用的数据创建分布式张量。
向量化 (vmap)
vmap(fun[, in_axes, out_axes, axis_name, ...])
向量化映射。
numpy.vectorize(pyfunc, *[, excluded, signature])
定义一个支持广播的向量化函数。
并行化 (pmap)
pmap(fun[, axis_name, in_axes, out_axes, ...])
支持集体操作的并行映射。
devices([backend])
返回给定后端的所有设备列表。
local_devices([process_index, backend, host_id])
类似于 jax.devices(),但仅返回给定进程局部的设备。
process_index([backend])
返回此进程的整数进程索引。
device_count([backend])
返回设备的总数。
local_device_count([backend])
返回此进程可寻址的设备数量。
process_count([backend])
返回与后端关联的 JAX 进程数。
Callbacks
pure_callback(callback, result_shape_dtypes, ...)
调用一个纯 Python 回调函数。
experimental.io_callback(callback, ...[, ...])
调用一个非纯 Python 回调函数。
debug.callback(callback, *args[, ordered])
调用一个可分期的 Python 回调函数。
debug.print(fmt, *args[, ordered])
打印值,并在分期 JAX 函数中工作。
Miscellaneous
Device
可用设备的描述符。
print_environment_info([return_string])
返回一个包含本地环境和 JAX 安装信息的字符串。
live_arrays([platform])
返回后端平台上的所有活动数组。
clear_caches()
清除所有编译和分期缓存。
jax.numpy 模块
原文:jax.readthedocs.io/en/latest/jax.numpy.html
采用jax.lax中的原语实现 NumPy API。
虽然 JAX 尽可能地遵循 NumPy API,但有时无法完全遵循 NumPy 的规范。
值得注意的是,由于 JAX 数组是不可变的,不能在 JAX 中实现原地变换数组的 NumPy API。但是,JAX 通常能够提供纯函数的替代 API。例如,替代原地数组更新(x[i] = y),JAX 提供了一个纯索引更新函数 x.at[i].set(y)(参见ndarray.at)。
类似地,一些 NumPy 函数在可能时经常返回数组的视图(例如transpose()和reshape())。JAX 版本的这类函数将返回副本,尽管在使用jax.jit()编译操作序列时,XLA 通常会进行优化。
NumPy 在将值提升为float64类型时非常积极。JAX 在类型提升方面有时不那么积极(请参阅类型提升语义)。
一些 NumPy 例程具有依赖数据的输出形状(例如unique()和nonzero())。因为 XLA 编译器要求在编译时知道数组形状,这些操作与 JIT 不兼容。因此,JAX 在这些函数中添加了一个可选的size参数,可以在静态指定以便与 JIT 一起使用。
几乎所有适用的 NumPy 函数都在jax.numpy命名空间中实现;它们如下所列。
ndarray.at
用于索引更新功能的辅助属性。
abs(x, /)
jax.numpy.absolute()的别名。
absolute(x, /)
计算逐元素的绝对值。
acos(x, /)
逐元素的反余弦函数。
acosh(x, /)
逐元素的反双曲余弦函数。
add(x1, x2, /)
逐元素相加。
all(a[, axis, out, keepdims, where])
测试沿给定轴的所有数组元素是否为 True。
allclose(a, b[, rtol, atol, equal_nan])
如果两个数组在容差范围内逐元素相等,则返回 True。
amax(a[, axis, out, keepdims, initial, where])
返回数组或沿轴的最大值。
amin(a[, axis, out, keepdims, initial, where])
返回数组或沿轴的最小值。
angle(z[, deg])
返回复数或数组的角度。
any(a[, axis, out, keepdims, where])
测试沿给定轴的任何数组元素是否为 True。
append(arr, values[, axis])
返回将值附加到原始数组末尾的新数组。
apply_along_axis(func1d, axis, arr, *args, ...)
沿给定轴向数组的 1-D 切片应用函数。
apply_over_axes(func, a, axes)
在多个轴上重复应用函数。
arange(start[, stop, step, dtype])
返回给定间隔内的均匀间隔值。
arccos(x, /)
反余弦,逐元素计算。
arccosh(x, /)
逆双曲余弦,逐元素计算。
arcsin(x, /)
反正弦,逐元素计算。
arcsinh(x, /)
逆双曲正弦,逐元素计算。
arctan(x, /)
反三角正切,逐元素计算。
arctan2(x1, x2, /)
根据 x1/x2 的值选择正确的象限,逐元素计算反正切。
arctanh(x, /)
逆双曲正切,逐元素计算。
argmax(a[, axis, out, keepdims])
返回沿轴的最大值的索引。
argmin(a[, axis, out, keepdims])
返回沿轴的最小值的索引。
argpartition(a, kth[, axis])
返回部分排序数组的索引。
argsort(a[, axis, kind, order, stable, ...])
返回排序数组的索引。
argwhere(a, *[, size, fill_value])
查找非零数组元素的索引。
around(a[, decimals, out])
将数组四舍五入到指定的小数位数。
array(object[, dtype, copy, order, ndmin])
创建一个数组。
array_equal(a1, a2[, equal_nan])
如果两个数组具有相同的形状和元素则返回 True。
array_equiv(a1, a2)
如果输入数组形状一致且所有元素相等则返回 True。
array_repr(arr[, max_line_width, precision, ...])
返回数组的字符串表示。
array_split(ary, indices_or_sections[, axis])
将数组分割为多个子数组。
array_str(a[, max_line_width, precision, ...])
返回数组中数据的字符串表示。
asarray(a[, dtype, order, copy])
将输入转换为数组。
asin(x, /)
反正弦,逐元素计算。
asinh(x, /)
逆双曲正弦,逐元素计算。
astype(x, dtype, /, *[, copy, device])
将数组复制到指定的数据类型。
atan(x, /)
反三角正切,逐元素计算。
atanh(x, /)
逆双曲正切,逐元素计算。
atan2(x1, x2, /)
根据 x1/x2 的值选择正确的象限,逐元素计算反正切。
atleast_1d()
将输入转换为至少有一维的数组。
atleast_2d()
将输入视为至少有两个维度的数组。
atleast_3d()
将输入视为至少有三个维度的数组。
average()
沿指定轴计算加权平均值。
bartlett(M)
返回 Bartlett 窗口。
bincount(x[, weights, minlength, length])
计算整数数组中每个值的出现次数。
bitwise_and(x1, x2, /)
逐元素计算两个数组的按位与操作。
bitwise_count(x, /)
计算每个元素的绝对值的二进制表示中 1 的位数。
bitwise_invert(x, /)
计算按位求反,逐元素计算。
bitwise_left_shift(x1, x2, /)
将整数的位向左移动。
bitwise_not(x, /)
计算按位取反(bit-wise NOT),即按位取反,对每个元素进行操作。
bitwise_or(x1, x2, /)
计算两个数组按位或的结果。
bitwise_right_shift(x1, x2, /)
将整数的位向右移动。
bitwise_xor(x1, x2, /)
计算两个数组按位异或的结果。
blackman(M)
返回 Blackman 窗口。
block(arrays)
从嵌套的块列表中组装一个多维数组。
bool_
bool 的别名
broadcast_arrays(*args)
广播任意数量的数组。
broadcast_shapes()
将输入的形状广播为单个形状。
broadcast_to(array, shape)
将数组广播到新的形状。
c_
沿着最后一个轴连接切片、标量和类数组对象。
can_cast(from_, to[, casting])
根据转换规则,如果可以进行数据类型转换,则返回 True。
cbrt(x, /)
返回数组的立方根,按元素操作。
cdouble
complex128 的别名
ceil(x, /)
返回输入的上限值,按元素操作。
character()
所有字符字符串标量类型的抽象基类。
choose(a, choices[, out, mode])
根据索引数组和数组列表选择构造数组。
clip([x, min, max, a, a_min, a_max])
将数组中的值限制在给定范围内。
column_stack(tup)
将一维数组按列堆叠成二维数组。
complex_
complex128 的别名
complex128(x)
complex64(x)
complexfloating()
所有由浮点数构成的复数数值标量类型的抽象基类。
ComplexWarning
在将复数数据类型强制转换为实数数据类型时引发的警告。
compress(condition, a[, axis, size, ...])
使用布尔条件沿指定轴压缩数组。
concat(arrays, /, *[, axis])
沿着现有轴连接一系列数组。
concatenate(arrays[, axis, dtype])
沿着指定轴连接一系列数组。
conj(x, /)
返回复数的共轭,按元素操作。
conjugate(x, /)
返回复数的共轭,按元素操作。
convolve(a, v[, mode, precision, ...])
计算两个一维数组的卷积。
copy(a[, order])
返回给定对象的数组副本。
copysign(x1, x2, /)
将 x1 的符号改为 x2 的符号,按元素操作。
corrcoef(x[, y, rowvar])
返回皮尔逊积矩相关系数。
correlate(a, v[, mode, precision, ...])
计算两个一维数组的相关性。
cos(x, /)
计算元素的余弦值。
cosh(x, /)
双曲余弦,按元素操作。
count_nonzero(a[, axis, keepdims])
统计数组a中的非零值数量。
cov(m[, y, rowvar, bias, ddof, fweights, ...])
估算给定数据和权重的协方差矩阵。
cross(a, b[, axisa, axisb, axisc, axis])
返回两个(向量)数组的叉积。
csingle
complex64的别名。
cumprod(a[, axis, dtype, out])
返回沿给定轴的元素的累积乘积。
cumsum(a[, axis, dtype, out])
返回沿给定轴的元素的累积和。
cumulative_sum(x, /, *[, axis, dtype, ...])
deg2rad(x, /)
将角度从度转换为弧度。
degrees(x, /)
将弧度从弧度转换为度。
delete(arr, obj[, axis, assume_unique_indices])
从数组中删除条目或条目。
diag(v[, k])
提取对角线或构造对角线数组。
diag_indices(n[, ndim])
返回访问数组主对角线的索引。
diag_indices_from(arr)
返回 n 维数组的主对角线的访问索引。
diagflat(v[, k])
用扁平化输入创建一个二维数组的对角线。
diagonal(a[, offset, axis1, axis2])
返回指定对角线。
diff(a[, n, axis, prepend, append])
计算给定轴的第 n 个离散差异。
digitize(x, bins[, right])
返回输入数组中每个值所属的箱体的索引。
divide(x1, x2, /)
按元素划分参数。
divmod(x1, x2, /)
同时返回按元素的商和余数。
dot(a, b, *[, precision, preferred_element_type])
计算两个数组的点积。
double
float64的别名。
dsplit(ary, indices_or_sections)
沿第 3 轴(深度)将数组分割成多个子数组。
dstack(tup[, dtype])
深度方向上序列堆叠数组(沿着第三个轴)。
dtype(dtype[, align, copy])
创建一个数据类型对象。
ediff1d(ary[, to_end, to_begin])
数组中连续元素的差异。
einsum()
爱因斯坦求和。
einsum_path()
在不评估 einsum 的情况下计算最佳收缩路径。
empty(shape[, dtype, device])
返回给定形状和类型的新数组,不初始化条目。
empty_like(prototype[, dtype, shape, device])
返回与给定数组相同形状和类型的新数组。
equal(x1, x2, /)
按元素返回(x1 == x2)。
exp(x, /)
计算输入数组中所有元素的指数。
exp2(x, /)
计算输入数组中所有 p 的 2**p。
expand_dims(a, axis)
将长度为 1 的维度插入数组。
expm1(x, /)
计算数组中所有元素的exp(x) - 1。
extract(condition, arr, *[, size, fill_value])
返回满足条件的数组元素。
eye(N[, M, k, dtype])
返回对角线上为 1 的二维数组,其他位置为 0。
fabs(x, /)
计算每个元素的绝对值。
fill_diagonal(a, val[, wrap, inplace])
填充给定任意维度数组的主对角线。
finfo(dtype)
浮点类型的机器限制。
fix(x[, out])
四舍五入到最近的整数朝向零。
flatnonzero(a, *[, size, fill_value])
返回扁平化数组中非零元素的索引。
flexible()
所有没有预定义长度的标量类型的抽象基类。
flip(m[, axis])
沿指定轴翻转数组元素的顺序。
fliplr(m)
沿轴 1 翻转数组元素的顺序。
flipud(m)
沿轴 0 翻转数组元素的顺序。
float_
float64 的别名。
float_power(x1, x2, /)
逐元素地将第一个数组的元素提升为第二个数组的幂。
float16(x)
float32(x)
float64(x)
floating()
所有浮点标量类型的抽象基类。
floor(x, /)
逐元素返回输入的下限。
floor_divide(x1, x2, /)
返回输入除法的最大整数小于或等于结果的元素。
fmax(x1, x2)
数组元素的逐元素最大值。
fmin(x1, x2)
数组元素的逐元素最小值。
fmod(x1, x2, /)
返回除法的元素余数。
frexp(x, /)
将 x 的元素分解为尾数和二次指数。
frombuffer(buffer[, dtype, count, offset])
将缓冲区解释为一维数组。
fromfile(*args, **kwargs)
jnp.fromfile 的未实现 JAX 封装器。
fromfunction(function, shape, *[, dtype])
通过对每个坐标执行函数来构造数组。
fromiter(*args, **kwargs)
jnp.fromiter 的未实现 JAX 封装器。
frompyfunc(func, /, nin, nout, *[, identity])
从任意 JAX 兼容的标量函数创建一个 JAX ufunc。
fromstring(string[, dtype, count])
从字符串中的文本数据初始化一个新的一维数组。
from_dlpack(x, /, *[, device, copy])
从实现了__dlpack__的对象创建一个 NumPy 数组。
full(shape, fill_value[, dtype, device])
返回给定形状和类型的新数组,并填充 fill_value。
full_like(a, fill_value[, dtype, shape, device])
返回与给定数组形状和类型相同的全数组。
gcd(x1, x2)
返回 |x1| 和 |x2| 的最大公约数。
generic()
NumPy 标量类型的基类。
geomspace(start, stop[, num, endpoint, ...])
返回等间隔的对数刻度上的数字(等比数列)。
get_printoptions()
返回当前的打印选项。
gradient(f, *varargs[, axis, edge_order])
返回 N 维数组的梯度。
greater(x1, x2, /)
返回逐元素 (x1 > x2) 的真值。
greater_equal(x1, x2, /)
返回逐元素 (x1 >= x2) 的真值。
hamming(M)
返回 Hamming 窗口。
hanning(M)
返回 Hanning 窗口。
heaviside(x1, x2, /)
计算 Heaviside 阶跃函数。
histogram(a[, bins, range, weights, density])
计算数据集的直方图。
histogram_bin_edges(a[, bins, range, weights])
计算直方图使用的箱子的边缘。
histogram2d(x, y[, bins, range, weights, ...])
计算两个数据样本的二维直方图。
histogramdd(sample[, bins, range, weights, ...])
计算一些数据的多维直方图。
hsplit(ary, indices_or_sections)
水平(按列)将数组分割为多个子数组。
hstack(tup[, dtype])
按序列水平(按列)堆叠数组。
hypot(x1, x2, /)
给定直角三角形的“腿”,返回其斜边长度。
i0
第一类修正贝塞尔函数,阶数为 0。
identity(n[, dtype])
返回单位数组。
iinfo(int_type)
imag(val, /)
返回复数参数的虚部。
index_exp
用于构建数组索引元组的更好方式。
indices()
返回表示网格的索引数组。
inexact()
所有数值标量类型的抽象基类,其值的表示(可能)是不精确的,如浮点数。
inner(a, b, *[, precision, ...])
计算两个数组的内积。
insert(arr, obj, values[, axis])
在给定索引之前,沿着指定的轴插入值。
int_
int64的别名
int16(x)
int32(x)
int64(x)
int8(x)
integer()
所有整数标量类型的抽象基类。
interp(x, xp, fp[, left, right, period])
单调递增样本点的一维线性插值。
intersect1d(ar1, ar2[, assume_unique, ...])
计算两个一维数组的交集。
invert(x, /)
按位求反,即按位非,逐元素进行操作。
isclose(a, b[, rtol, atol, equal_nan])
返回一个布尔数组,其中两个数组在每个元素级别上是否在指定的公差内相等。
iscomplex(x)
返回一个布尔数组,如果输入元素是复数则为 True。
iscomplexobj(x)
检查复数类型或复数数组。
isdtype(dtype, kind)
返回一个布尔值,指示提供的 dtype 是否属于指定的 kind。
isfinite(x, /)
测试每个元素是否有限(既不是无穷大也不是非数)。
isin(element, test_elements[, ...])
确定element中的元素是否出现在test_elements中。
isinf(x, /)
逐元素测试是否为正或负无穷大。
isnan(x, /)
逐元素测试是否为 NaN,并返回布尔数组结果。
isneginf(x, /[, out])
逐元素测试是否为负无穷大,返回布尔数组结果。
isposinf(x, /[, out])
逐元素测试是否为正无穷大,返回布尔数组结果。
isreal(x)
返回一个布尔数组,如果输入元素是实数则为 True。
isrealobj(x)
如果 x 是非复数类型或复数数组,则返回 True。
isscalar(element)
如果 element 的类型是标量类型,则返回 True。
issubdtype(arg1, arg2)
如果第一个参数在类型层次结构中低于或等于第二个参数的类型码,则返回 True。
iterable(y)
检查对象是否可迭代。
ix_(*args)
从 N 个一维序列返回多维网格(开放网格)。
kaiser(M, beta)
返回 Kaiser 窗口。
kron(a, b)
两个数组的 Kronecker 乘积。
lcm(x1, x2)
返回 `
ldexp(x1, x2, /)
返回 x1 * 2**x2,逐元素操作。
left_shift(x1, x2, /)
将整数的位左移。
less(x1, x2, /)
逐元素返回 (x1 < x2) 的真值。
less_equal(x1, x2, /)
逐元素返回 (x1 <= x2) 的真值。
lexsort(keys[, axis])
使用一系列键执行间接稳定排序。
linspace()
返回指定间隔内的均匀间隔数字。
load(*args, **kwargs)
从 .npy、.npz 或 pickled 文件中加载数组或序列化对象。
log(x, /)
自然对数,逐元素操作。
log10(x, /)
返回输入数组的以 10 为底的对数,逐元素操作。
log1p(x, /)
返回输入数组加 1 的自然对数,逐元素操作。
log2(x, /)
x 的以 2 为底的对数,逐元素操作。
logaddexp
输入指数的对数之和。
logaddexp2
以 2 为底的指数输入的对数之和。
logical_and(*args)
逐元素计算 x1 AND x2 的真值。
logical_not(*args)
逐元素计算 NOT x 的真值。
logical_or(*args)
逐元素计算 x1 OR x2 的真值。
logical_xor(*args)
逐元素计算 x1 XOR x2 的真值。
logspace(start, stop[, num, endpoint, base, ...])
返回对数刻度上均匀分布的数字。
mask_indices(*args, **kwargs)
给定掩码函数,返回访问 (n, n) 数组的索引。
matmul(a, b, *[, precision, ...])
执行矩阵乘法。
matrix_transpose(x, /)
转置数组的最后两个维度。
max(a[, axis, out, keepdims, initial, where])
返回数组或沿轴的最大值。
maximum(x1, x2, /)
逐元素计算数组元素的最大值。
mean(a[, axis, dtype, out, keepdims, where])
沿指定轴计算算术平均值。
median(a[, axis, out, overwrite_input, keepdims])
沿指定轴计算中位数。
meshgrid(*xi[, copy, sparse, indexing])
从坐标向量返回坐标矩阵的元组。
mgrid
返回密集的多维网格。
min(a[, axis, out, keepdims, initial, where])
返回数组或沿轴的最小值。
minimum(x1, x2, /)
逐元素计算数组元素的最小值。
mod(x1, x2, /)
返回除法的元素余数。
modf(x, /[, out])
返回数组元素的整数部分和小数部分。
moveaxis(a, source, destination)
将数组轴移动到新位置
multiply(x1, x2, /)
对参数逐元素相乘。
nan_to_num(x[, copy, nan, posinf, neginf])
将 NaN 替换为零,将无穷大替换为大的有限数(默认
nanargmax(a[, axis, out, keepdims])
返回忽略指定轴上的 NaN 的最大值的索引
nanargmin(a[, axis, out, keepdims])
返回忽略指定轴上的 NaN 的最小值的索引
nancumprod(a[, axis, dtype, out])
返回沿指定轴对数组元素的累积积,处理 NaN 为
nancumsum(a[, axis, dtype, out])
返回沿指定轴对数组元素的累积和,处理 NaN 为
nanmax(a[, axis, out, keepdims, initial, where])
返回数组或指定轴上的最大值,忽略任何 NaN
nanmean(a[, axis, dtype, out, keepdims, where])
计算沿指定轴的算术平均值,忽略 NaN
nanmedian(a[, axis, out, overwrite_input, ...])
计算沿指定轴的中位数,忽略 NaN
nanmin(a[, axis, out, keepdims, initial, where])
返回数组或指定轴上的最小值,忽略任何 NaN
nanpercentile(a, q[, axis, out, ...])
计算沿指定轴的数据的第 q 分位数,
nanprod(a[, axis, dtype, out, keepdims, ...])
返回沿指定轴对数组元素求积,处理 NaN 为
nanquantile(a, q[, axis, out, ...])
计算沿指定轴的数据的第 q 分位数,
nanstd(a[, axis, dtype, out, ddof, ...])
计算沿指定轴的标准差,忽略 NaN
nansum(a[, axis, dtype, out, keepdims, ...])
返回沿指定轴对数组元素求和,处理 NaN 为
nanvar(a[, axis, dtype, out, ddof, ...])
计算沿指定轴的方差,忽略 NaN
ndarray
Array 的别名。
ndim(a)
返回数组的维数。
negative(x, /)
数值取反,逐元素操作。
nextafter(x1, x2, /)
返回 x1 朝向 x2 的下一个浮点数值,逐元素操作。
nonzero(a, *[, size, fill_value])
返回数组中非零元素的索引。
not_equal(x1, x2, /)
逐元素返回 (x1 != x2)。
number()
所有数值标量类型的抽象基类。
object_
任何 Python 对象。
ogrid
返回开放多维“网格”。
ones(shape[, dtype, device])
返回给定形状和类型的新数组,填充为 1。
ones_like(a[, dtype, shape, device])
返回与给定数组具有相同形状和类型的填充为 1 的数组。
outer(a, b[, out])
计算两个向量的外积。
packbits(a[, axis, bitorder])
将二值数组的元素打包为 uint8 数组中的位。
pad(array, pad_width[, mode])
对数组进行填充。
partition(a, kth[, axis])
返回数组的部分排序副本。
percentile(a, q[, axis, out, ...])
计算沿指定轴的数据的第 q 个百分位数。
permute_dims(a, /, axes)
返回通过转置轴的数组。
piecewise(x, condlist, funclist, *args, **kw)
计算分段定义的函数。
place(arr, mask, vals, *[, inplace])
根据条件和输入值改变数组的元素。
poly(seq_of_zeros)
根据给定的根序列找到多项式的系数。
polyadd(a1, a2)
计算两个多项式的和。
polyder(p[, m])
返回多项式指定阶数的导数。
polydiv(u, v, *[, trim_leading_zeros])
返回多项式除法的商和余数。
polyfit(x, y, deg[, rcond, full, w, cov])
最小二乘多项式拟合。
polyint(p[, m, k])
返回多项式的不定积分(反导数)。
polymul(a1, a2, *[, trim_leading_zeros])
计算两个多项式的乘积。
polysub(a1, a2)
两个多项式的差(减法)。
polyval(p, x, *[, unroll])
在特定值处计算多项式的值。
positive(x, /)
数值的正值,逐元素操作。
pow(x1, x2, /)
将第一个数组元素按第二个数组元素的幂进行元素级操作。
power(x1, x2, /)
将第一个数组元素按第二个数组元素的幂进行元素级操作。
printoptions(*args, **kwargs)
设置打印选项的上下文管理器。
prod(a[, axis, dtype, out, keepdims, ...])
返回给定轴上数组元素的乘积。
promote_types(a, b)
返回二进制操作应将其参数转换为的类型。
ptp(a[, axis, out, keepdims])
沿某个轴的值范围(最大值 - 最小值)。
put(a, ind, v[, mode, inplace])
用给定值替换数组的指定元素。
quantile(a, q[, axis, out, overwrite_input, ...])
计算沿指定轴的数据的第 q 个分位数。
r_
沿第一个轴连接切片、标量和类数组对象。
rad2deg(x, /)
将角度从弧度转换为度。
radians(x, /)
将角度从度转换为弧度。
ravel(a[, order])
将数组展平为一维形状。
ravel_multi_index(multi_index, dims[, mode, ...])
将多维索引转换为平坦索引。
real(val, /)
返回复数参数的实部。
reciprocal(x, /)
返回参数的倒数,逐元素操作。
remainder(x1, x2, /)
返回除法的元素级余数。
repeat(a, repeats[, axis, total_repeat_length])
将数组中每个元素重复指定次数。
reshape(a[, shape, order, newshape])
返回数组的重塑副本。
resize(a, new_shape)
返回具有指定形状的新数组。
result_type(*args)
返回应用于 NumPy 的结果类型。
right_shift(x1, x2, /)
将 x1 的位向右移动到指定的 x2 量。
rint(x, /)
将数组元素四舍五入到最接近的整数。
roll(a, shift[, axis])
沿指定轴滚动数组元素。
rollaxis(a, axis[, start])
将指定的轴滚动到给定位置。
roots(p, *[, strip_zeros])
返回具有给定系数的多项式的根。
rot90(m[, k, axes])
在由轴指定的平面中将数组旋转 90 度。
round(a[, decimals, out])
将数组四舍五入到指定的小数位数。
round_(a[, decimals, out])
将数组四舍五入到指定的小数位数。
s_
用于构建数组索引元组的更好方式。
save(file, arr[, allow_pickle, fix_imports])
将数组以 NumPy .npy 格式保存到二进制文件中。
savez(file, *args, **kwds)
以未压缩的 .npz 格式将多个数组保存到单个文件中。
searchsorted(a, v[, side, sorter, method])
在排序数组内执行二分搜索。
select(condlist, choicelist[, default])
根据条件从 choicelist 中选择元素返回数组。
set_printoptions([precision, threshold, ...])
设置打印选项。
setdiff1d(ar1, ar2[, assume_unique, size, ...])
计算两个一维数组的差集。
setxor1d(ar1, ar2[, assume_unique])
计算两个数组中元素的异或。
shape(a)
返回数组的形状。
sign(x, /)
返回数的元素级别符号指示。
signbit(x, /)
返回元素级别的 True,其中设置了符号位(小于零)。
signedinteger()
所有有符号整数标量类型的抽象基类。
sin(x, /)
按元素计算三角正弦。
sinc(x, /)
返回归一化的 sinc 函数。
single
float32 的别名。
sinh(x, /)
按元素计算双曲正弦。
size(a[, axis])
返回给定轴上的元素数量。
sort(a[, axis, kind, order, stable, descending])
返回数组的排序副本。
sort_complex(a)
使用实部先排序复杂数组,然后按虚部排序。
split(ary, indices_or_sections[, axis])
将数组拆分为多个子数组,作为 ary 的视图。
sqrt(x, /)
返回数组元素的非负平方根。
square(x, /)
返回输入数组的按元素平方。
squeeze(a[, axis])
从数组中移除一个或多个长度为 1 的轴。
stack(arrays[, axis, out, dtype])
沿新轴连接序列的数组。
std(a[, axis, dtype, out, ddof, keepdims, ...])
沿指定轴计算标准差。
subtract(x1, x2, /)
逐元素地进行减法运算。
sum(a[, axis, dtype, out, keepdims, ...])
沿给定轴对数组元素求和。
swapaxes(a, axis1, axis2)
交换数组的两个轴。
take(a, indices[, axis, out, mode, ...])
从数组中取出元素。
take_along_axis(arr, indices, axis[, mode, ...])
从数组中取出元素。
tan(x, /)
计算元素的正切。
tanh(x, /)
计算元素的双曲正切。
tensordot(a, b[, axes, precision, ...])
计算两个 N 维数组的张量点积。
tile(A, reps)
通过重复 A 指定的次数构造一个数组。
trace(a[, offset, axis1, axis2, dtype, out])
返回数组的对角线之和。
trapezoid(y[, x, dx, axis])
使用复合梯形规则沿指定轴积分。
transpose(a[, axes])
返回 N 维数组的转置版本。
tri(N[, M, k, dtype])
一个在给定对角线及其以下位置为 1,其他位置为 0 的数组。
tril(m[, k])
数组的下三角形。
tril_indices(n[, k, m])
返回(n, m)数组的下三角形的索引。
tril_indices_from(arr[, k])
返回数组 arr 的下三角形的索引。
trim_zeros(filt[, trim])
从一维数组或序列中修剪前导和/或尾随的零。
triu(m[, k])
数组的上三角形。
triu_indices(n[, k, m])
返回(n, m)数组的上三角形的索引。
triu_indices_from(arr[, k])
返回数组 arr 的上三角形的索引。
true_divide(x1, x2, /)
逐元素地进行除法运算。
trunc(x)
返回输入元素的截断值。
ufunc(func, /, nin, nout, *[, name, nargs, ...])
在整个数组上逐元素操作的函数。
uint
uint64的别名。
uint16(x)
uint32(x)
uint64(x)
uint8(x)
union1d(ar1, ar2, *[, size, fill_value])
计算两个 1D 数组的并集。
unique(ar[, return_index, return_inverse, ...])
返回数组中的唯一值。
unique_all(x, /, *[, size, fill_value])
返回 x 的唯一值以及索引、逆索引和计数。
unique_counts(x, /, *[, size, fill_value])
返回 x 的唯一值及其计数。
unique_inverse(x, /, *[, size, fill_value])
返回 x 的唯一值以及索引、逆索引和计数。
unique_values(x, /, *[, size, fill_value])
返回 x 的唯一值以及索引、逆索引和计数。
unpackbits(a[, axis, count, bitorder])
将 uint8 数组的元素解包为二进制值输出数组。
unravel_index(indices, shape)
将扁平索引转换为多维索引。
unstack(x, /, *[, axis])
unsignedinteger()
所有无符号整数标量类型的抽象基类。
unwrap(p[, discont, axis, period])
通过取周期的补集来展开数组。
vander(x[, N, increasing])
生成范德蒙矩阵。
var(a[, axis, dtype, out, ddof, keepdims, ...])
计算沿指定轴的方差。
vdot(a, b, *[, precision, ...])
执行两个 1D 向量的共轭乘法。
vecdot(x1, x2, /, *[, axis, precision, ...])
执行两个批量向量的共轭乘法。
vectorize(pyfunc, *[, excluded, signature])
定义一个具有广播功能的向量化函数。
vsplit(ary, indices_or_sections)
按垂直(行)方向将数组分割成多个子数组。
vstack(tup[, dtype])
沿垂直(行)方向堆叠数组序列。
where()
根据条件从两个数组中选择元素。
zeros(shape[, dtype, device])
返回一个给定形状和类型的全零数组。
zeros_like(a[, dtype, shape, device])
返回与给定数组相同形状和类型的全零数组。
jax.numpy.fft
fft(a[, n, axis, norm])
计算一维离散傅里叶变换。
fft2(a[, s, axes, norm])
计算二维离散傅里叶变换。
fftfreq(n[, d, dtype])
返回离散傅里叶变换的样本频率。
fftn(a[, s, axes, norm])
计算 N 维离散傅里叶变换。
fftshift(x[, axes])
将零频率分量移动到频谱中心。
hfft(a[, n, axis, norm])
计算具有 Hermitian 对称性的信号的 FFT。
ifft(a[, n, axis, norm])
计算一维离散傅里叶逆变换。
ifft2(a[, s, axes, norm])
计算二维离散傅里叶逆变换。
ifftn(a[, s, axes, norm])
计算 N 维离散傅里叶逆变换。
ifftshift(x[, axes])
fftshift 的逆操作。
ihfft(a[, n, axis, norm])
计算具有 Hermitian 对称性的信号的逆 FFT。
irfft(a[, n, axis, norm])
计算 rfft 的逆变换。
irfft2(a[, s, axes, norm])
计算 rfft2 的逆变换。
irfftn(a[, s, axes, norm])
计算 rfftn 的逆变换。
rfft(a[, n, axis, norm])
计算一维实数输入的离散傅里叶变换。
rfft2(a[, s, axes, norm])
计算实数组的二维 FFT。
rfftfreq(n[, d, dtype])
返回离散傅里叶变换的样本频率。
| rfftn(a[, s, axes, norm]) | 计算实数输入的 N 维离散傅里叶变换。 | ## jax.numpy.linalg
cholesky(a, *[, upper])
计算矩阵的 Cholesky 分解。
cond(x[, p])
计算矩阵的条件数。
cross(x1, x2, /, *[, axis])
计算两个 3D 向量的叉乘。
det
计算数组的行列式。
diagonal(x, /, *[, offset])
提取矩阵或矩阵堆栈的对角线元素。
eig(a)
计算方阵的特征值和特征向量。
eigh(a[, UPLO, symmetrize_input])
计算 Hermitian 矩阵的特征值和特征向量。
eigvals(a)
计算一般矩阵的特征值。
eigvalsh(a[, UPLO])
计算 Hermitian 矩阵的特征值。
inv(a)
返回方阵的逆。
lstsq(a, b[, rcond, numpy_resid])
返回线性方程组的最小二乘解。
matmul(x1, x2, /, *[, precision, ...])
执行矩阵乘法。
matrix_norm(x, /, *[, keepdims, ord])
计算矩阵或矩阵堆栈的范数。
matrix_power(a, n)
将方阵提升到整数幂。
matrix_rank(M[, rtol, tol])
计算矩阵的秩。
matrix_transpose(x, /)
转置矩阵或矩阵堆栈。
multi_dot(arrays, *[, precision])
高效计算数组序列之间的矩阵乘积。
norm(x[, ord, axis, keepdims])
计算矩阵或向量的范数。
outer(x1, x2, /)
计算两个一维数组的外积。
pinv(a[, rtol, hermitian, rcond])
计算(Moore-Penrose)伪逆。
qr()
计算数组的 QR 分解。
slogdet(a, *[, method])
计算数组行列式的符号和(自然)对数。
solve(a, b)
解线性方程组。
svd()
计算奇异值分解。
svdvals(x, /)
计算矩阵的奇异值。
tensordot(x1, x2, /, *[, axes, precision, ...])
计算两个 N 维数组的张量点积。
tensorinv(a[, ind])
计算数组的张量逆。
tensorsolve(a, b[, axes])
解张量方程 a x = b 以得到 x。
trace(x, /, *[, offset, dtype])
计算矩阵的迹。
vector_norm(x, /, *[, axis, keepdims, ord])
计算向量或向量批次的范数。
vecdot(x1, x2, /, *[, axis, precision, ...])
计算(批量)向量共轭点积。
JAX Array
JAX Array(以及其别名 jax.numpy.ndarray)是 JAX 中的核心数组对象:您可以将其视为 JAX 中与numpy.ndarray 等效的对象。与 numpy.ndarray 一样,大多数用户不需要手动实例化 Array 对象,而是通过 jax.numpy 函数如 array()、arange()、linspace() 和上面列出的其他函数来创建它们。
复制和序列化
JAX Array对象设计为在适当的情况下与 Python 标准库工具无缝配合。
使用内置copy模块时,当copy.copy()或copy.deepcopy()遇到Array时,等效于调用copy()方法,该方法将在与原始数组相同设备上创建缓冲区的副本。在追踪/JIT 编译的代码中,这将正确工作,尽管在此上下文中,复制操作可能会被编译器省略。
当内置pickle模块遇到Array时,它将通过紧凑的位表示方式对其进行序列化,类似于对numpy.ndarray对象的处理。解封后,结果将是一个新的Array对象在默认设备上。这是因为通常情况下,pickling 和 unpickling 可能发生在不同的运行环境中,并且没有通用的方法将一个运行时环境的设备 ID 映射到另一个的设备 ID。如果在追踪/JIT 编译的代码中使用pickle,将导致ConcretizationTypeError。
jax.numpy.fft.fft
原文:jax.readthedocs.io/en/latest/_autosummary/jax.numpy.fft.fft.html
jax.numpy.fft.fft(a, n=None, axis=-1, norm=None)
计算一维离散傅里叶变换。
numpy.fft.fft() 的 LAX-backend 实现。
下面是原始文档字符串。
此函数使用高效的快速傅里叶变换(FFT)算法计算一维 n-点离散傅里叶变换(DFT)[CT]。
参数:
a (array_like) – 输入数组,可以是复数。
n (int, optional) – 输出的变换轴的长度。如果 n 小于输入的长度,则会截取输入。如果 n 较大,则在末尾用零填充输入。如果未提供 n,则使用由 axis 指定的轴上的输入长度。
axis (int, optional) – 计算 FFT 的轴。如果未给出,则使用最后一个轴。
norm ({"backward"**, "ortho"**, "forward"}**, optional) – 规范化方式,可选。
返回值:
out – 截断或零填充的输入,沿由 axis 指示的轴进行变换,如果未指定 axis,则为最后一个轴。
返回类型:
复数 ndarray
参考文献
[CT]
Cooley, James W., and John W. Tukey, 1965, “An algorithm for the machine calculation of complex Fourier series,” Math. Comput. 19: 297-301.
jax.numpy.fft.fft2
原文:jax.readthedocs.io/en/latest/_autosummary/jax.numpy.fft.fft2.html
jax.numpy.fft.fft2(a, s=None, axes=(-2, -1), norm=None)
计算二维离散傅立叶变换。
numpy.fft.fft2()的 LAX 后端实现。
以下是原始文档字符串。
此函数通过快速傅立叶变换(FFT)计算M维数组中的任何轴上的n维离散傅立叶变换。默认情况下,变换计算输入数组的最后两个轴上的变换,即二维 FFT。
参数:
a(array_like) – 输入数组,可以是复数
s(整数序列,可选) –
输出的形状(每个转换轴的长度)(s[0]指代轴 0,s[1]指代轴 1 等)。这对应于fft(x, n)中的n。沿着每个轴,如果给定的形状比输入小,则截断输入。如果大,则用零填充输入。
自 2.0 版更改:如果为-1,则使用整个输入(无填充/修剪)。
如果未提供s,则使用指定轴上输入的形状。
自 2.0 版起已弃用:如果s不是None,则axes也不能是None。
自 2.0 版起已弃用:s必须仅包含int,而不是None值。当前None值意味着在相应的一维变换中使用n的默认值,但此行为已弃用。
axes(整数序列,可选) –
计算 FFT 的轴。如果未给出,则使用最后两个轴。轴中的重复索引表示在该轴上执行多次变换。单元素序列表示执行一维 FFT。默认值:(-2, -1)。
自 2.0 版起已弃用:如果指定了s,则要转换的相应轴不能为None。
norm({"backward","ortho","forward"},可选)
返回:
out – 通过指定的轴变换的截断或零填充输入,或者如果未给出axes,则为最后两个轴。
返回类型:
复数ndarray
jax.numpy.fft.fftfreq
原文:jax.readthedocs.io/en/latest/_autosummary/jax.numpy.fft.fftfreq.html
jax.numpy.fft.fftfreq(n, d=1.0, *, dtype=None)
返回离散傅立叶变换的采样频率。
LAX 后端实现的numpy.fft.fftfreq()。
以下是原始文档字符串。
返回的浮点数数组 f 包含以每个采样间距单位的频率单元为周期的频率箱中心(从起始点开始为零)。例如,如果采样间距以秒为单位,则频率单位为每秒循环数。
给定窗口长度 n 和采样间距 d:
f = [0, 1, ..., n/2-1, -n/2, ..., -1] / (d*n) if n is even
f = [0, 1, ..., (n-1)/2, -(n-1)/2, ..., -1] / (d*n) if n is odd
参数:
n(int)– 窗口长度。
d(标量,可选)– 采样间距(采样率的倒数)。默认为 1。
dtype(可选)– 返回频率的数据类型。如果未指定,将使用 JAX 的默认浮点数数据类型。
返回值:
f – 长度为 n 的包含采样频率的数组。
返回类型:
ndarray
jax.numpy.fft.fftn
原文:jax.readthedocs.io/en/latest/_autosummary/jax.numpy.fft.fftn.html
jax.numpy.fft.fftn(a, s=None, axes=None, norm=None)
计算 N 维离散傅里叶变换。
numpy.fft.fftn() 的 LAX 后端实现。
原始文档字符串如下。
该函数通过快速傅里叶变换(FFT)在 M 维数组中的任意数量的轴上计算 N 维离散傅里叶变换。
参数:
a(array_like) – 输入数组,可以是复数。
s(整数序列,可选) –
输出的各个转换轴的形状(s[0] 指代轴 0,s[1] 指代轴 1,等等)。这对应于 fft(x, n) 中的 n。沿任何轴,如果给定的形状比输入的小,则输入会被裁剪。如果形状比输入大,则输入将用零填充。
在版本 2.0 中更改:如果是 -1,则使用整个输入(无填充/修剪)。
如果未给出 s,则沿 axes 指定的轴使用输入的形状。
从版本 2.0 开始弃用:如果 s 不是 None,则轴也不能是 None。
从版本 2.0 开始弃用:s 必须仅包含 int 值,而不能是 None 值。当前 None 值意味着在相应的 1-D 变换中使用默认值 n,但此行为已弃用。
axes(整数序列,可选) –
要计算 FFT 的轴。如果未给出,则使用最后 len(s) 个轴,或者如果 s 也未指定,则使用所有轴。在 axes 中重复的索引意味着该轴上的变换执行多次。
从版本 2.0 开始弃用:如果指定了 s,则必须显式指定要转换的对应轴。
norm({"backward","ortho","forward"},可选)
返回:
out – 被截断或零填充的输入,在由 axes 指示的轴上进行转换,或者根据上述参数部分中的 s 和 a 的组合。
返回类型:
复数 ndarray
jax.numpy.fft.fftshift
原文:jax.readthedocs.io/en/latest/_autosummary/jax.numpy.fft.fftshift.html
jax.numpy.fft.fftshift(x, axes=None)
将零频率分量移动到频谱中心。
LAX 后端实现的 numpy.fft.fftshift()。
以下是原始文档字符串。
此函数对列出的所有轴交换了半空间(默认为所有轴)。注意,只有当 len(x) 为偶数时,y[0] 才是奈奎斯特分量。
参数:
x(array_like) – 输入数组。
axes(int 或 形状元组,可选) – 要进行移位的轴。默认为 None,即移动所有轴。
返回值:
y – 移位后的数组。
返回类型:
ndarray
jax.numpy.fft.hfft
原文:jax.readthedocs.io/en/latest/_autosummary/jax.numpy.fft.hfft.html
jax.numpy.fft.hfft(a, n=None, axis=-1, norm=None)
计算具有 Hermitian 对称性(即实数
numpy.fft.hfft()的 LAX 后端实现。
下面是原始文档字符串。
谱。
参数:
a(array_like) – 输入数组。
n(int,可选) – 输出的转换轴的长度。对于 n 个输出点,需要n//2 + 1个输入点。如果输入比这个长,则裁剪。如果输入比这个短,则用零填充。如果未提供 n,则取为2*(m-1),其中 m 是由轴指定的输入的长度。
axis(int,可选) – 计算 FFT 的轴。如果未指定,则使用最后一个轴。
norm({"backward"**, "ortho"**, *"forward"}**,可选)
返回:
out – 被截断或用零填充的输入,在由 axis 指示的轴上变换,如果未指定 axis,则在最后一个轴上变换。转换轴的长度为 n,如果未提供 n,则为2*m - 2,其中 m 是输入的转换轴的长度。为了得到奇数个输出点,必须指定 n,例如在典型情况下为2*m - 1,
返回类型:
ndarray
jax.numpy.fft.ifft
原文:jax.readthedocs.io/en/latest/_autosummary/jax.numpy.fft.ifft.html
jax.numpy.fft.ifft(a, n=None, axis=-1, norm=None)
计算一维逆离散傅里叶变换。
numpy.fft.ifft() 的 LAX 后端实现。
下面是原始文档字符串。
此函数计算由 fft 计算的一维 n 点离散傅里叶变换的逆变换。换句话说,ifft(fft(a)) == a,在数值精度范围内成立。有关算法和定义的一般描述,请参阅 numpy.fft。
输入应按 fft 返回的方式排序,即,
a[0] 应包含零频率项,
a[1:n//2] 应包含正频率项,
a[n//2 + 1:] 应包含负频率项,按最负频率开始的递增顺序排列。
对于偶数个输入点,A[n//2] 表示正和负奈奎斯特频率值的总和,因为这两者被混合在一起。有关详细信息,请参阅 numpy.fft。
Parameters:
a (array_like) – 输入数组,可以是复数。
n (int, 可选) – 输出的转换轴的长度。如果 n 小于输入的长度,则对输入进行裁剪。如果大于输入,则用零填充。如果未给出 n,则使用由 axis 指定的轴的输入长度。有关填充问题的注释,请参阅注释。
axis (int, 可选) – 计算逆离散傅里叶变换的轴。如果未给出,则使用最后一个轴。
norm ({"backward"**, "ortho"**, "forward"}**, 可选)
Returns:
out – 沿由 axis 指定的轴变换后的截断或零填充输入,或者如果未指定 axis,则为最后一个轴。
Return type:
复数 ndarray
jax.numpy.fft.ifft2
原文:jax.readthedocs.io/en/latest/_autosummary/jax.numpy.fft.ifft2.html
jax.numpy.fft.ifft2(a, s=None, axes=(-2, -1), norm=None)
计算二维逆离散傅里叶变换。
LAX 后端实现的 numpy.fft.ifft2()。
下面是原始的文档字符串。
此函数通过快速傅里叶变换(FFT)在 M 维数组中的任意数量的轴上计算二维离散傅里叶逆变换。换句话说,ifft2(fft2(a)) == a,在数值精度内成立。默认情况下,计算逆变换是在输入数组的最后两个轴上进行的。
输入的顺序与 fft2 返回的顺序相同,即应该在两个轴的低阶角落中有零频率项,这两个轴的第一半中有正频率项,中间有奈奎斯特频率项,并且两个轴的后半部分中有负频率项,按照递减负频率的顺序。
参数:
a (类似数组) – 输入数组,可以是复数。
s (整数序列, 可选) –
输出的形状(每个轴的长度)(s[0] 对应轴 0,s[1] 对应轴 1,依此类推)。这对应于 ifft(x, n) 的 n。沿每个轴,如果给定形状比输入小,则对输入进行裁剪。如果形状更大,则用零填充输入。
自版本 2.0 起已更改:如果为 -1,则使用整个输入(无填充/修剪)。
如果未给出 s,则使用由 axes 指定的轴上的输入形状。有关 ifft 零填充问题的问题,请参见注释。
自版本 2.0 起已废弃:若 s 不为 None,则 axes 也不能为 None。
自版本 2.0 起已废弃:s 必须只包含 int 值,不能包含 None 值。目前 None 值意味着在对应的一维变换中使用默认值 n,但此行为已被弃用。
axes (整数序列, 可选) –
用于计算 FFT 的轴。如果未指定,则使用最后两个轴。在 axes 中重复的索引表示对该轴执行多次变换。一个元素的序列表示执行一维 FFT。默认值:(-2, -1)。
自版本 2.0 起已废弃:若指定了 s,则要转换的相应轴不能为 None。
norm ({"backward", "ortho", "forward"}, 可选)
返回:
out – 在由 axes 指示的轴上变换的截断或零填充输入,或如果未给出 axes,则在最后两个轴上变换。
返回类型:
复数 ndarray
jax.numpy.fft.ifftn
jax.readthedocs.io/en/latest/_autosummary/jax.numpy.fft.ifftn.html
jax.numpy.fft.ifftn(a, s=None, axes=None, norm=None)
计算 N 维逆离散傅立叶变换。
[numpy.fft.ifftn()的 LAX 后端实现](https://numpy.org/doc/stable/reference/generated/numpy.fft.ifftn.html#numpy.fft.ifftn ("在 NumPy v2.0 中"))。
以下是原始文档字符串。
该函数通过快速傅里叶变换(FFT)在 M 维数组中的任意数量的轴上,计算 N 维福里叶变换的逆。换句话说,ifftn(fftn(a)) == a在数值精度内成立。有关使用的定义和约定的描述,请参见 numpy.fft。
输入与 ifft 类似,应以与 fftn 返回的方式相同的顺序排序,即应在低阶角落中具有所有轴的零频率项,在所有轴的前半部分具有正频率项,在所有轴的中间具有奈奎斯特频率项,并且在所有轴的后半部分具有负频率项,按照递减负频率的顺序排列。
参数:
a (array_like) – 输入数组,可以是复数。
s (整数的序列,可选) –
输出的形状(每个转换轴的长度)(s[0]指轴 0,s[1]指轴 1,以此类推)。这对应于ifft(x, n)的n。沿任何轴,如果给定的形状小于输入的形状,则会对输入进行裁剪。如果大于输入,则用零填充输入。
在版本 2.0 中更改:如果为-1,则使用整个输入(无填充/修剪)。
如果未给出s,则使用由 axes 指定的轴的输入形状。参见关于 ifft 零填充问题的注释。
从版本 2.0 开始已弃用:如果s不是None,则轴也不能是None。
从版本 2.0 开始已弃用:s必须只包含int,而不是None值。None值当前表示在相应的 1-D 变换中使用n的默认值,但此行为已弃用。
axes (整数的序列,可选) –
计算逆离散傅里叶变换的轴。如果未给出,则使用最后的len(s)轴,或者如果也未指定s,则使用所有轴。轴中的重复索引意味着在该轴上执行多次逆变换。
从版本 2.0 开始已弃用:如果指定了s,则必须明确指定要转换的相应轴。
norm ({"backward"**, "ortho"**, "forward"}**, 可选)
返回:
out – 截断或用零填充的输入,沿着由 axes 指示的轴,或由上面参数节中解释的 s 或 a 的组合。
返回类型:
复数的 ndarray
jax.numpy.fft.ifftshift
原文:jax.readthedocs.io/en/latest/_autosummary/jax.numpy.fft.ifftshift.html
jax.numpy.fft.ifftshift(x, axes=None)
fftshift 的反操作。对于偶数长度的 x,它们是相同的。
LAX 后端实现的 numpy.fft.ifftshift()。
以下是原始文档字符串。
函数对于奇数长度的 x 会有一个样本的差异。
参数:
x (array_like) – 输入数组。
axes (int 或 形状元组**, 可选) – 用于计算的轴。默认为 None,即对所有轴进行移位。
返回值:
y – 移位后的数组。
返回类型:
ndarray
jax.numpy.fft.ihfft
原文:jax.readthedocs.io/en/latest/_autosummary/jax.numpy.fft.ihfft.html
jax.numpy.fft.ihfft(a, n=None, axis=-1, norm=None)
计算具有 Hermitian 对称性的信号的逆 FFT。
LAX 后端实现的numpy.fft.ihfft()。
以下是原始文档字符串。
参数:
a (array_like) – 输入数组。
n (int, optional) – 逆 FFT 的长度,即用于输入的变换轴上的点数。如果 n 小于输入的长度,则输入被截断。如果大于输入,则用零填充。如果未给出 n,则使用由轴指定的输入的长度。
axis (int, optional) – 计算逆 FFT 的轴。如果未给出,则使用最后一个轴。
norm ({"backward"**, "ortho"**, "forward"}**, optional)
返回:
out – 截断或零填充的输入,在指定的轴上进行变换,如果未指定轴,则为最后一个轴。变换后的轴的长度为n//2 + 1。
返回类型:
复数 ndarray
jax.numpy.fft.irfft
原文:jax.readthedocs.io/en/latest/_autosummary/jax.numpy.fft.irfft.html
jax.numpy.fft.irfft(a, n=None, axis=-1, norm=None)
计算 rfft 的逆操作。
numpy.fft.irfft() 的 LAX 后端实现。
以下为原始文档字符串。
此函数计算由 rfft 计算的实输入的一维 n 点离散傅立叶变换的逆变换。换句话说,irfft(rfft(a), len(a)) == a 在数值精度内成立。(有关为何在这里需要 len(a) 的详细信息,请参阅下面的注释。)
输入应该是由 rfft 返回的形式,即实部的零频率项,后跟复数正频率项,按频率递增的顺序排列。由于实输入的离散傅立叶变换是共轭对称的,负频率项被视为对应正频率项的复共轭。
参数:
a (array_like) – 输入数组。
n (int, optional) – 输出的转换轴的长度。对于 n 个输出点,需要 n//2+1 个输入点。如果输入长于此,它将被截断。如果输入短于此,则用零填充。如果未给出 n,则取 2*(m-1),其中 m 是由轴指定的输入的长度。
axis (int, optional) – 计算逆 FFT 的轴。如果未给出,则使用最后一个轴。
norm ({"backward"**, "ortho"**, "forward"}**, optional)
返回:
out – 被截断或零填充的输入,沿着指定的轴变换,如果未指定轴,则沿最后一个轴。转换后的轴的长度为 n,或者如果未给出 n,则为 2*(m-1),其中 m 是输入的转换轴的长度。要获得奇数个输出点,必须指定 n。
返回类型:
ndarray
jax.numpy.fft.irfft2
原文:jax.readthedocs.io/en/latest/_autosummary/jax.numpy.fft.irfft2.html
jax.numpy.fft.irfft2(a, s=None, axes=(-2, -1), norm=None)
计算 rfft2 的逆。
numpy.fft.irfft2() 的 LAX-backend 实现。
以下为原始文档注释。
参数:
a (array_like) – 输入数组
s (ints 序列,可选) –
逆 FFT 输出的形状。
自 2.0 版本更改:如果为 -1,则使用整个输入(无填充/修剪)。
自 2.0 版本弃用:如果 s 不为 None,则轴也不能为 None。
自 2.0 版本弃用:s 必须仅包含 int 值,而不是 None 值。当前的 None 值意味着在相应的 1-D 变换中使用 n 的默认值,但此行为已弃用。
axes (ints 序列,可选) –
要计算逆 fft 的轴。默认:(-2, -1),即最后两个轴。
自 2.0 版本弃用:如果指定了 s,则要转换的相应轴不能为 None。
norm ({"backward"**, "ortho"**, "forward"}**, 可选)
返回:
out – 逆实 2-D FFT 的结果。
返回类型:
ndarray
jax.numpy.fft.irfftn
原文:jax.readthedocs.io/en/latest/_autosummary/jax.numpy.fft.irfftn.html
jax.numpy.fft.irfftn(a, s=None, axes=None, norm=None)
计算 rfftn 的逆。
numpy.fft.irfftn()的 LAX 后端实现.
以下是原始文档字符串。
此函数通过快速傅里叶变换(FFT)计算 N 维实输入的逆离散傅里叶变换,涵盖 M 维数组中的任意数量轴。换句话说,irfftn(rfftn(a), a.shape)在数值精度范围内等于a。(a.shape对于 irfft 是必要的,就像对于 irfft 一样,出于同样的原因。)
输入应按与由 rfftn 返回的相同方式排序,即对于最终变换轴的 irfft,以及对于所有其他轴的 ifftn。
参数:
a(类似数组) – 输入数组。
s(整数序列,可选的) –
输出的形状(每个转换轴的长度)(s[0]指轴 0,s[1]指轴 1 等)。s也是沿此轴使用的输入点数,除了最后一个轴,输入的点数为s[-1]//2+1。沿任何轴,如果s指示的形状比输入小,则输入被裁剪。如果更大,则用零填充输入。
自版本 2.0 更改:如果为-1,则使用整个输入(无填充/修剪)。
如果未给出s,则沿着由axes指定的轴使用输入的形状。除了最后一个轴被视为2*(m-1),其中m是沿该轴的输入长度。
自版本 2.0 起不推荐使用:如果s不为None,则axes也不得为None。
自版本 2.0 起不推荐使用:s必须只包含整数,而不能包含None值。目前None值意味着在相应的 1-D 变换中使用默认值n,但此行为已弃用。
axes(整数序列,可选的) –
要计算逆 FFT 的轴。如果未给出,则使用最后的len(s)个轴,或者如果也未指定s,则使用所有轴。在axes中重复的索引意味着在该轴上执行多次逆变换。
自版本 2.0 起不推荐使用:如果指定了s,则必须显式指定要转换的相应轴。
norm({"backward",* "ortho",* "forward"},可选的)
返回:
out – 经过轴指示的变换,截断或填充零的输入,或者通过参数部分上述的 s 或 a 的组合进行变换。每个转换后轴的长度由相应的 s 的元素给出,或者如果未给出 s,则在除最后一个轴外的每个轴上都是输入的长度。当未给出 s 时,最终变换轴上的输出长度为 2*(m-1),其中 m 是输入的最终变换轴的长度。要在最终轴上得到奇数个输出点,必须指定 s。
Return type:
ndarray
jax.numpy.fft.rfft
原文:jax.readthedocs.io/en/latest/_autosummary/jax.numpy.fft.rfft.html
jax.numpy.fft.rfft(a, n=None, axis=-1, norm=None)
计算实数输入的一维离散傅里叶变换。
numpy.fft.rfft()的 LAX 后端实现。
下面是原始文档字符串。
此函数通过一种称为快速傅里叶变换(FFT)的高效算法计算实值数组的一维n点离散傅里叶变换(DFT)。
Parameters:
a (array_like) – 输入数组
n (int, 可选) – 输入中变换轴上要使用的点数。如果 n 小于输入的长度,则截取输入。如果 n 大于输入长度,则用零填充输入。如果未给出 n,则使用由 axis 指定的轴上的输入长度。
axis (int, 可选) – 执行 FFT 的轴。如果未给出,则使用最后一个轴。
norm ({"backward"**, "ortho"**, "forward"}**, 可选)
Returns:
out – 截断或零填充的输入,沿 axis 指示的轴变换,如果未指定 axis,则为最后一个轴。如果 n 是偶数,则变换轴的长度为(n/2)+1。如果 n 是奇数,则长度为(n+1)/2。
Return type:
复数的 ndarray
jax.numpy.fft.rfft2
原文:jax.readthedocs.io/en/latest/_autosummary/jax.numpy.fft.rfft2.html
jax.numpy.fft.rfft2(a, s=None, axes=(-2, -1), norm=None)
计算实数组的二维 FFT。
numpy.fft.rfft2() 的 LAX 后端实现。
下面是原始文档字符串。
参数:
a(数组) – 输入数组,被视为实数。
s(整数序列,可选) –
FFT 的形状。
2.0 版本中更改:如果是 -1,则使用整个输入(无填充/修剪)。
自 2.0 版本起弃用:如果 s 不是 None,则轴也不能是 None。
自 2.0 版本起弃用:s 必须仅包含 int 值,而非 None 值。目前 None 值意味着在相应的一维变换中使用 n 的默认值,但此行为已弃用。
axes(整数序列,可选) –
要计算 FFT 的轴。默认值:(-2, -1)。
自 2.0 版本起弃用:如果指定了 s,则要转换的相应轴不能为 None。
norm({"backward"**, "ortho"**, "forward"},可选)
返回:
out – 实数 2-D FFT 的结果。
返回类型:
ndarray
jax.numpy.fft.rfftfreq
原文:jax.readthedocs.io/en/latest/_autosummary/jax.numpy.fft.rfftfreq.html
jax.numpy.fft.rfftfreq(n, d=1.0, *, dtype=None)
返回离散傅里叶变换样本频率。
numpy.fft.rfftfreq() 的 LAX 后端实现。
以下是原始文档字符串。
(用于 rfft, irfft)。
返回的浮点数组 f 包含以每个采样间隔为单位的频率箱中心(从起始处为零)。例如,如果采样间隔以秒为单位,则频率单位为每秒循环数。
给定窗口长度 n 和采样间隔 d:
f = [0, 1, ..., n/2-1, n/2] / (d*n) if n is even
f = [0, 1, ..., (n-1)/2-1, (n-1)/2] / (d*n) if n is odd
与 fftfreq 不同(但类似于 scipy.fftpack.rfftfreq),将奈奎斯特频率分量视为正值。
参数:
n (int – 窗口长度。
d (标量, 可选) – 采样间隔(采样率的倒数)。默认为 1。
dtype (可选) – 返回频率的数据类型。如果未指定,则使用 JAX 的默认浮点数数据类型。
返回:
f – 长度为 n//2 + 1 的数组,包含采样频率。
返回类型:
ndarray
jax.numpy.fft.rfftn
原文:jax.readthedocs.io/en/latest/_autosummary/jax.numpy.fft.rfftn.html
jax.numpy.fft.rfftn(a, s=None, axes=None, norm=None)
计算实输入的 N 维离散傅里叶变换。
numpy.fft.rfftn() 的 LAX 后端实现.
以下为原始文档字符串。
此函数通过快速傅里叶变换(FFT)对 M 维实数组中的任意数量轴执行 N 维离散傅里叶变换。默认情况下,所有轴都被转换,实变换在最后一个轴上执行,而其余变换是复数。
参数:
a(类数组) - 输入数组,假定为实数。
s(整数序列,可选) -
要使用的输入的每个转换轴上的形状(长度)。(s[0] 是指轴 0,s[1] 是指轴 1,依此类推)。对于 rfft(x, n),s 的最后一个元素对应于 n,而对于其余轴,它对应于 fft(x, n) 的 n。沿着任何轴,如果给定的形状小于输入的形状,则输入被裁剪。如果它更大,则输入被填充为零。
版本 2.0 中的更改:如果为-1,则使用整个输入(无填充/修剪)。
如果未给出 s,则使用由轴指定的输入的形状。
自版本 2.0 起弃用:如果 s 不是 None,则轴也不能是 None。
自版本 2.0 起弃用:s 必须仅包含整数,不能是 None 值。目前 None 值意味着对应 1-D 变换中 n 的默认值,但此行为已弃用。
axes(整数序列,可选) -
用于计算 FFT 的轴。如果未给出,则使用最后 len(s) 个轴,如果也未指定 s,则使用所有轴。
自版本 2.0 起弃用:如果指定了 s,则必须显式指定要转换的对应轴。
norm({"backward", "ortho", "forward"},可选)
返回:
out – 通过轴指示的轴或由参数部分上述 s 和 a 的组合截断或零填充的输入,最后一个转换轴的长度将为 s[-1]//2+1,而其余转换的轴将具有根据 s 或保持与输入不变的长度。
返回类型:
复数 ndarray