谷歌最近开源了Metrax,这是一个JAX库,提供标准化和高性能的指标实现,适用于分类、回归、自然语言处理、视觉和音频模型。
Metrax填补了JAX生态系统中的一个空白。谷歌解释说,这个空白迫使许多从TensorFlow迁移到JAX的团队自己实现常见的评估指标,如准确率、F1、均方根误差等。
虽然创建指标对某些人来说似乎是一个相当简单和直接的任务,但在考虑到数据中心规模的分布式计算环境中的大规模训练和评估时,这就变得不那么简单了。
Metrax为一系列机器学习模型提供了预定义的评估指标,包括分类、回归、推荐、视觉和音频,特别支持分布式和大规模训练环境。对于视觉模型,该库包括交并比(IoU)、信噪比(SNR)和结构相似性指数(SSIM)等指标。Metrax还包括强大的与自然语言处理相关的指标,如困惑度、BLEU和ROUGE。
谷歌指出,Metrax的目标之一是确保所有指标都得到良好的实现并遵循最佳实践。在指标定义支持的情况下,Metrax使用了先进的JAX功能,如vmap和jit来提升性能。例如,这些功能用于新“at K”指标的实现,以便能够并行计算多个K值。这使得可以更全面和更快速地评估模型。
您可以使用
PrecisionAtK来确定模型在多个K值(例如,K=1,K=8和K=20)下的精度,所有这些都在一次通过模型的过程中完成,而不需要为每个参数多次调用PrecisionAtK。
DevOps工程师在Substack上以Neural Foundry的名义写道:
Metrax支持在一次通过中计算多个K值的事实对排名系统来说是一个巨大的胜利。每次我切换项目时都要重写指标工具,这种标准化早该实现了。API看起来也很简洁。好奇他们是否针对特定用例(如大规模推荐管道)与自定义实现进行了基准测试。
以下代码片段展示了如何计算精度指标给定预测和标签。可以指定一个可选的阈值,将概率预测转换为二进制预测:
import metrax
# Directly compute the metric state.
metric_state = metrax.Precision.from_model_output(
predictions=predictions,
labels=labels,
threshold=0.5
)
# The result is then readily available by calling compute().
result = metric_state.compute()
result谷歌还发布了一个笔记本,包含一整套示例,包括多设备扩展和与Flax NNX的集成。Flax NNX是一个简化的API,使得在JAX中创建、检查、调试和分析神经网络变得更容易。
