Ohhnews

分类导航

$ cd ..
Baeldung原文

Deep Java Library (DJL) 入门指南

#djl#java#机器学习#aws#图像识别

[LOADING...]

1. 概述

在本教程中,我们将了解 Deep Java Library (DJL),这是由 AWS 开发的与引擎无关的 机器学习 框架。

诸如 PyTorch、TensorFlow、MXNet 和 ONNX 等 Python 库在开发和执行深度学习神经网络方面处于领先地位。因此,Java 开发人员在处理使用人工智能的应用程序时面临困难。

DJL 通过提供一个与各种 ML 引擎交互的单一接口,实现了对这些引擎的抽象。 在本文中,我们将使用 DJL 构建一个简单的程序,用于从图像中识别手写数字。尽管 DJL 支持模型训练和推理,但我们将重点关注推理,并从 DJL 的公共 Model Zoo 仓库加载一个预训练的图像分类模型。

2. 核心概念

DJL API 提供了一个用于与机器学习引擎协作的标准框架:

[LOADING...]

使用 ML Python 库的应用程序必须依赖特定于引擎的 API。这会产生紧密耦合,增加复杂性和维护开销。相比之下,DJL 是一个依赖项极少的轻量级库,它使应用程序能够透明地与底层的 ML 引擎交互。 从本质上讲,DJL 并不取代 ML 引擎,而是提供一个统一的 Java API,在运行时将执行委托给选定的引擎。因此,在这些库之间切换是无缝的,只需极少的工作量。

此外,DJL 可以访问预训练模型的集中注册表,称为 Model Zoo。这些模型适用于常见的用例,例如图像识别、自然语言处理和词到向量转换。

该库可以加载这些模型,然后使用数据集调用它们以生成所需的输出:

[LOADING...]

该库从 Model Zoo 中发现一个针对 ML 引擎和特定用途量身定制的模型。 然后,它将数据集预处理为底层引擎可以理解的格式。该库使用转换后的数据集调用底层 ML 引擎并获取输出。接着,它将输出转换为 Java 程序可以理解的格式。最后,释放并清理所有资源。

在接下来的章节中,我们将了解先决条件依赖项、重要的库组件以及我们将实现的一个用例。

3. 关键 Java 组件

让我们探索 DJL 的关键 Java 组件:

[LOADING...]

Criteria.Builder 类有助于定义 ML 模型搜索标准,例如模型名称、模型输入和输出参数以及模型应用,然后创建一个 Criteria 对象。Criteria#loadModel() 方法随后直接从 Model Zoo 加载 ML 模型。

Model#load() 接口方法允许应用程序从本地缓存加载 ML 模型。根据用例(例如训练模型或使用模型进行预测),应用程序可以使用 newTrainer()newPredictor() 方法来创建 TrainerPredictor 对象。 Predictor#predict() 方法根据给定的输入(例如图像、音频转录或文本)预测输出。此外,Predictor#batchPredict() 提供了处理多个输入并生成相应输出列表的灵活性。

Predictor 依赖于 Translator 将输入对象转换为底层 ML 引擎可以理解的格式。此外,Translator 将 ML 引擎的输出转换为应用程序可以理解的格式。例如,它将图像或音频文件对象转换为 ML 引擎可以处理的 n 维表示。DJL 库提供了 Translator 接口的几个内置实现,例如 ImageClassificationTranslatorSpeechRecognitionTranslatorObjectDetectionTranslator。在更专业的场景中,开发人员还可以实现 Translator 接口来处理自定义的预处理和后处理逻辑。

这些概念在接下来的章节中将变得更加清晰,届时我们将实现一个图像识别用例,以识别图像中的单个手写数字。

4. 先决条件

首先,我们从 Maven 导入 DJL Bill of Materials (BOM),以确保所有 DJL 依赖项的版本一致:

$ xml
<dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>bom</artifactId>
            <version>0.36.0</version>
            <type>pom</type>
            <scope>import</scope>
        </dependency>
    </dependencies>
</dependencyManagement>

接下来,我们添加 DJL Model Zoo 模块,它提供对托管在公共仓库中的预训练模型的访问:

$ xml
<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>model-zoo</artifactId>
    <version>0.36.0</version>
</dependency>

最后,DJL 需要一个特定于运行时引擎的依赖项来执行模型:

$ xml
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-engine</artifactId>
</dependency>

我们包含 pytorch-engine 库,以使用 PyTorch 作为底层的 ML 引擎。

最后,对于图像识别用例,我们将使用在经典 MNIST 数据集(28×28 手写数字 0-9)上预训练的 PyTorch 模型。

5. 图像识别用例实现

现在,我们可以实现图像识别用例,以识别图像中的单个手写数字。

首先,让我们定义一个 DigitIdentifier 类:

$ java
public class DigitIdentifier {
    public String identifyDigit(String imagePath)
            throws ModelNotFoundException, MalformedModelException, IOException, TranslateException {
        Criteria<Image, Classifications> criteria = Criteria.builder()
          .optApplication(Application.CV.IMAGE_CLASSIFICATION)
          .setTypes(Image.class, Classifications.class)
          .optFilter("dataset", "mnist")
          .build();
        ZooModel<Image, Classifications> model = criteria.loadModel();
        Classifications classifications = null;
        try (Predictor<Image, Classifications> predictor = model.newPredictor()) {
            classifications = predictor.predict(this.loadImage(imagePath));
        }
        return classifications.best().getClassName();
    }
}

在该类中,identifyDigit() 方法首先使用 Criteria 对象加载在 MNIST 数据集上训练的计算机视觉模型。然后,我们调用 model#newPredictor() 来获取 Predictor 对象。接下来,我们将数字图像的路径传递给 Predictor#predict() 方法以获取 Classifications 对象。Classifications 对象由多个 Classification 对象组成,这些对象基本上代表带有准确性分数的预测。此外,我们不需要遍历所有 Classification 对象,而是通过调用 Classifications#best() 方法来选择最佳结果。

现在,让我们看一组从 MNIST 测试数据集中提取的数字 3 的手写图像: [LOADING...]

我们将运行 identifyDigit() 方法,看看它能否预测测试图像数据集中的数字 3:

$ java
@ParameterizedTest
@ValueSource(strings = { 
  "data/3_991.png", "data/3_1028.png", 
  "data/3_9882.png", "data/3_9996.png" 
})
void whenRunModel_thenIdentifyDigitCorrectly(String imagePath) throws Exception {
    DigitIdentifier digitIdentifier = new DigitIdentifier();
    String identifiedDigit = digitIdentifier.identifyDigit(imagePath);
    assertEquals("3", identifiedDigit);
}

参数化 JUnit 测试方法使用测试图像的路径调用 DigitIdentifier#identifyDigit() 方法。我们发现 ML 模型正确预测了图像文件中的数字。

机器学习模型进行概率性预测,因此即使对于相似的图像,它们也并不总是产生正确的结果。 训练数据集的质量和代表性在很大程度上决定了预测准确性。

6. 结论

在本文中,我们学习了 DJL API 的关键组件并实现了一个图像识别用例。这可以作为我们自行探索更多功能和用例的垫脚石。

DJL 抽象底层 ML 引擎的能力确实可以帮助 Java 开发人员为需要 ML 的应用程序做出贡献。 然而,作为先决条件,理解机器学习概念同样重要,并有助于正确采用。

像往常一样,本文中使用的源代码可以在 GitHub 上获取。 本文 A Guide to Deep Java Library 首次出现在 Baeldung 上。