使用 Java 和 TensorFlow 训练与部署神经网络模型
训练、导出并使用 TensorFlow 模型是深入了解驱动人工智能革命的大型语言模型(LLM)构建模块的绝佳途径。由于我擅长使用 Java,我将利用它来定义一个神经网络(NN)模型,进行训练,以语言无关的格式导出,然后将其导入到 Spring Boot 项目中。当然,从零开始完成所有这些工作并不明智,因为神经网络领域有许多进展,如果完全从头理解并实现,不仅耗时,而且极易出错。因此,为了既能学习神经网络又能简化实现过程,我们将使用一个成熟的软件平台:TensorFlow。
TensorFlow 是一个成熟且强大的平台,许多人使用它来构建和训练模型,但它几乎只用于 Python 编程语言。幸运的是,该项目通过将原生库封装为 Maven 依赖项,使其可以在 Java 中使用:https://github.com/tensorflow/java。
我们将使用 CPU 进行学习和运行,因为这很简单,除了添加特定的 Maven 依赖项外,无需进行额外操作。TensorFlow 平台也支持使用 GPU,但这需要额外的配置步骤。使用 CPU 的缺点是学习速度较慢(如果训练的 epoch 次数相同,模型性能是一样的)。本教程将涵盖完整流程,从收集数据以训练我们自己的简单分类模型,到导入并使用 TensorFlow 团队提供的预训练目标检测模型。
教程
在开始之前,如果我们想要获得可用的成果,需要熟悉一些与神经网络相关的基本概念:
- 层 (Layer):这是人工神经网络中的一种结构,由按一维数组排列的多个神经元组成。
- 权重初始化 (Weight initialization):神经网络中的每一层都有权重(训练期间更新的实际值);这些权重定义了训练后的神经网络。初始值通常设置为随机值,但这些随机值的选择方式很重要。选择特定的初始化器(Initializer)实现会影响学习所需的迭代次数(这些权重会加上偏置,以向上或向下平移整体数值,从而实现更快的训练)。
- 激活函数 (Activation function):应用于层权重的函数,通常用于增加神经网络学习复杂模式所需的非线性。
- 损失函数 (Loss function):仅在训练期间由优化器使用,用于更新权重和偏置。
- 优化器 (Optimizer):利用损失函数返回的值,优化器在神经网络层中遍历并更新权重和偏置的值。
Java TensorFlow 平台提供了这些概念的便捷实现(Initializer、Loss 和 Optimizer),因此可以轻松切换它们,并观察训练和最终模型在输入数据上的表现。不要害怕实验,你会发现许多神经网络拓扑结构都是实证研究的结果,即使在发表的论文中也是如此。
用于训练我们的模型的数据集是一个在许多教程中使用的经典数据集:鸢尾花数据集 (Iris plant data set)。它包含不同鸢尾花品种的花瓣和萼片的长度与宽度信息,以及所属品种。输入文件中的信息组织方式如下例所示:
在训练期间将保持此顺序;保持一致性非常重要。共有三个品种,每个品种有 50 个值。这三个品种将成为我们的特征 (features)。因为我们有三个特征,这类问题被称为分类 (classification) —— 运行模型的输出将是 3 个百分比值,相加等于 100%(例如:0.80, 0.01, 0.19)。例如,如果我们训练的是汇率预测模型(只有一个输出值——汇率),那将是一个回归类型的问题。
在鸢尾花分类训练期间,我们将获取模型的输出(例如:0.80, 0.01, 0.19),并检查对于给定的输入,概率最高的值是否属于预期品种的位置。这意味着输出的索引将始终对应同一个品种 —— 索引和品种由我们选择,可以是任何值,但必须保持一致。对于我们的训练,我们选择以下索引:
选择这些索引后,上述示例中的 0.80 值将被我们(神经网络拓扑的设计者)解释为该品种为 Iris Setosa 的概率为 80%。
现在我们对神经网络有了基本的了解,并确定了输入和输出,让我们开始编写训练器代码。我们将使用一种简单的神经网络拓扑,称为多层感知器 (multilayer perceptron),它具有两个隐藏层。这种拓扑可以用来解决一系列问题,并且非常适合我们的鸢尾花分类问题。
[LOADING...]
首先,我们为第 1 个隐藏层选择 5 个节点(神经元),为第 2 个隐藏层选择 4 个节点。如果需要,我们可以试验这些值以提高训练时间和性能。我们如何决定何时停止训练?为此,我们将预测输出与已知答案进行比较,当正确答案达到可接受的百分比时,我们停止训练。
以下所有代码片段均摘自此仓库。
首先是构建我们的网络:
请注意,我们命名了大多数操作(tensorFlowApi.withName(...)),以便在以后使用和保存时可以轻松检索它们。接下来是读取我们将用于训练的数据:
现在我们有了网络和数据,可以开始训练了:
请注意,我们检查预测值是否与预期值相同,并在训练 epoch 结束时打印预测正确的数量 —— 这是判断何时停止训练的简单检查方法。使用这种拓扑和 4 个训练 epoch,我们在 150 个样本中猜对了 124 个,这对于教程来说是可以接受的。通过调整拓扑、随机种子等,可以轻松实现更好的性能。
现在训练已达到可接受的性能,我们可以保存模型以便共享:
导出格式是 TensorFlow 特有的,但与语言无关,这意味着使用 Java API 保存的模型可以在 TensorFlow Python API 中使用,反之亦然。加载模型就像保存它一样简单:
列出签名是一个好习惯,这样可以了解可用内容以及函数名称 —— 它们可用于将数据传入模型,或者通过使用中间输入作为输出来使用模型的部分(如果它们在保存时通过将操作添加到签名中进行了导出)。加载模型后,我们可以使用它来获取会话:Session tfSession = model.session();
从现在开始,用于训练的相同代码(除损失函数和优化器外)可用于将数据传入模型并提取预测结果:
结语
至此,我们完成了训练、导出、加载和使用我们自己的模型的过程。但在大多数情况下,我们会想要使用他人创建并在线发布(例如在 https://www.kaggle.com/models?framework=tensorFlow2 上)的模型。下载模型后,可以使用与上述相同的步骤来加载和使用它。它们大多数都有完整的在线文档,但即使缺少文档,也可以在加载模型后使用 model.signatures() 从模型中提取信息。下面的示例是针对 https://www.kaggle.com/models/tensorflow/efficientdet/tensorFlow2/d0 的(该模型在上面链接的 GitHub 项目中被加载和使用):
DZone 贡献者所表达的观点仅代表其个人观点。