Deep Java Library (DJL) 入门指南
[LOADING...]
1. 概述
在本教程中,我们将了解 Deep Java Library (DJL),这是由 AWS 开发的与引擎无关的 机器学习 框架。
诸如 PyTorch、TensorFlow、MXNet 和 ONNX 等 Python 库在开发和执行深度学习神经网络方面处于领先地位。因此,Java 开发人员在处理使用人工智能的应用程序时面临困难。
DJL 通过提供一个与各种 ML 引擎交互的单一接口,实现了对这些引擎的抽象。 在本文中,我们将使用 DJL 构建一个简单的程序,用于从图像中识别手写数字。尽管 DJL 支持模型训练和推理,但我们将重点关注推理,并从 DJL 的公共 Model Zoo 仓库加载一个预训练的图像分类模型。
2. 核心概念
DJL API 提供了一个用于与机器学习引擎协作的标准框架:
使用 ML Python 库的应用程序必须依赖特定于引擎的 API。这会产生紧密耦合,增加复杂性和维护开销。相比之下,DJL 是一个依赖项极少的轻量级库,它使应用程序能够透明地与底层的 ML 引擎交互。 从本质上讲,DJL 并不取代 ML 引擎,而是提供一个统一的 Java API,在运行时将执行委托给选定的引擎。因此,在这些库之间切换是无缝的,只需极少的工作量。
此外,DJL 可以访问预训练模型的集中注册表,称为 Model Zoo。这些模型适用于常见的用例,例如图像识别、自然语言处理和词到向量转换。
该库可以加载这些模型,然后使用数据集调用它们以生成所需的输出:
该库从 Model Zoo 中发现一个针对 ML 引擎和特定用途量身定制的模型。 然后,它将数据集预处理为底层引擎可以理解的格式。该库使用转换后的数据集调用底层 ML 引擎并获取输出。接着,它将输出转换为 Java 程序可以理解的格式。最后,释放并清理所有资源。
在接下来的章节中,我们将了解先决条件依赖项、重要的库组件以及我们将实现的一个用例。
3. 关键 Java 组件
让我们探索 DJL 的关键 Java 组件:
Criteria.Builder 类有助于定义 ML 模型搜索标准,例如模型名称、模型输入和输出参数以及模型应用,然后创建一个 Criteria 对象。Criteria#loadModel() 方法随后直接从 Model Zoo 加载 ML 模型。
Model#load() 接口方法允许应用程序从本地缓存加载 ML 模型。根据用例(例如训练模型或使用模型进行预测),应用程序可以使用 newTrainer() 或 newPredictor() 方法来创建 Trainer 或 Predictor 对象。 Predictor#predict() 方法根据给定的输入(例如图像、音频转录或文本)预测输出。此外,Predictor#batchPredict() 提供了处理多个输入并生成相应输出列表的灵活性。
Predictor 依赖于 Translator 将输入对象转换为底层 ML 引擎可以理解的格式。此外,Translator 将 ML 引擎的输出转换为应用程序可以理解的格式。例如,它将图像或音频文件对象转换为 ML 引擎可以处理的 n 维表示。DJL 库提供了 Translator 接口的几个内置实现,例如 ImageClassificationTranslator、SpeechRecognitionTranslator 和 ObjectDetectionTranslator。在更专业的场景中,开发人员还可以实现 Translator 接口来处理自定义的预处理和后处理逻辑。
这些概念在接下来的章节中将变得更加清晰,届时我们将实现一个图像识别用例,以识别图像中的单个手写数字。
4. 先决条件
首先,我们从 Maven 导入 DJL Bill of Materials (BOM),以确保所有 DJL 依赖项的版本一致:
接下来,我们添加 DJL Model Zoo 模块,它提供对托管在公共仓库中的预训练模型的访问:
最后,DJL 需要一个特定于运行时引擎的依赖项来执行模型:
我们包含 pytorch-engine 库,以使用 PyTorch 作为底层的 ML 引擎。
最后,对于图像识别用例,我们将使用在经典 MNIST 数据集(28×28 手写数字 0-9)上预训练的 PyTorch 模型。
5. 图像识别用例实现
现在,我们可以实现图像识别用例,以识别图像中的单个手写数字。
首先,让我们定义一个 DigitIdentifier 类:
在该类中,identifyDigit() 方法首先使用 Criteria 对象加载在 MNIST 数据集上训练的计算机视觉模型。然后,我们调用 model#newPredictor() 来获取 Predictor 对象。接下来,我们将数字图像的路径传递给 Predictor#predict() 方法以获取 Classifications 对象。Classifications 对象由多个 Classification 对象组成,这些对象基本上代表带有准确性分数的预测。此外,我们不需要遍历所有 Classification 对象,而是通过调用 Classifications#best() 方法来选择最佳结果。
现在,让我们看一组从 MNIST 测试数据集中提取的数字 3 的手写图像: [LOADING...]
我们将运行 identifyDigit() 方法,看看它能否预测测试图像数据集中的数字 3:
该参数化 JUnit 测试方法使用测试图像的路径调用 DigitIdentifier#identifyDigit() 方法。我们发现 ML 模型正确预测了图像文件中的数字。
机器学习模型进行概率性预测,因此即使对于相似的图像,它们也并不总是产生正确的结果。 训练数据集的质量和代表性在很大程度上决定了预测准确性。
6. 结论
在本文中,我们学习了 DJL API 的关键组件并实现了一个图像识别用例。这可以作为我们自行探索更多功能和用例的垫脚石。
DJL 抽象底层 ML 引擎的能力确实可以帮助 Java 开发人员为需要 ML 的应用程序做出贡献。 然而,作为先决条件,理解机器学习概念同样重要,并有助于正确采用。
像往常一样,本文中使用的源代码可以在 GitHub 上获取。 本文 A Guide to Deep Java Library 首次出现在 Baeldung 上。