Ohhnews

分类导航

$ cd ..
Jetbrains Blog原文

如何在 PyCharm 中训练你的第一个 TensorFlow 模型

#tensorflow#pycharm#机器学习#深度学习#数据科学

这是一篇来自 YouTube 频道 Back To Engineering 创始人 Iulia Feroli 的客座文章。

[LOADING...]

TensorFlow 是一个强大的开源框架,用于构建机器学习和深度学习系统。其核心是处理张量(即多维数组),并提供诸如 Keras 之类的高级库,使您可以轻松地将原始数据转换为可训练、评估和部署的模型。

TensorFlow 能够帮助您处理整个流水线:加载和预处理数据、从层(layers)和激活函数(activations)组装模型、使用优化器和损失函数进行训练,以及导出模型以供服务使用,甚至在边缘设备上运行(包括 Raspberry Pi 和其他微控制器上的轻量级 TensorFlow Lite 模型)。

如果您希望构建数据驱动的应用程序、进行神经网络原型设计,或将模型交付到生产环境或设备中,学习 TensorFlow 将为您提供一套连贯且支持完善的工具包,助您从构思走向部署。

如果您是 TensorFlow 的新手,请先观看这篇 简短的概览视频。我在视频中解释了什么是张量、神经网络、层,以及为什么 TensorFlow 非常适合实现“数据 → 模型 → 部署”的流程,我还通过一个类似乐高积木的分类示例解释了这一切。

在这篇博文中,我将带您了解一个精简的 TensorFlow 实现 Notebook,以便我们从实践经验入手。您也可以观看配套的演示视频进行同步学习。

今天我们将探索一个非常简单的用例:加载 Fashion MNIST 数据集,构建两个简单的 Keras 模型,进行训练和对比,然后深入分析可视化结果(预测结果、置信度柱状图、混淆矩阵)。我保持了代码的极简和可读性,以便您可以专注于核心概念,同时您也会看到 PyCharm 是如何在此过程中提供帮助的。

逐步训练 TensorFlow 模型

在 PyCharm 中起步

我们将利用 PyCharm 的原生 Notebook 集成来构建 我们的项目。这样,我们可以检查流水线的每一步,并在过程中使用辅助可视化。我们将 创建一个新项目生成一个虚拟环境 来管理依赖项。

如果您运行的是附带仓库中的代码,可以直接从 requirements 文件安装依赖。如果您希望通过更多模型的可视化来扩展此示例,可以使用 PyCharm 的包管理器助手轻松 安装升级 包。

加载 Fashion MNIST 并检查数据

Fashion MNIST 是一个极佳的入门数据集,因为图像很小(28×28 像素),具有直观的意义且易于理解。它们以黑白像素图像的形式呈现各种服装类型,并为分类任务提供了相关的标签。我们可以先通过 matplotlib 函数打印一些图像来查看数据样本:

[LOADING...]

$ python
fig, axes = plt.subplots(2, 5, figsize=(10, 4))
for i, ax in enumerate(axes.flat):
    ax.imshow(x_train[i], cmap='gray')
    ax.set_title(class_names[y_train[i]])
    ax.axis('off')
plt.show()

两个简单的模型(快速实验)

$ python
model1 = models.Sequential([
    layers.Flatten(input_shape=(28, 28)),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax')
])
model2 = models.Sequential([
    layers.Flatten(input_shape=(28, 28)),
    layers.Dense(128, activation='relu'),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax')
])

编译并训练您的第一个模型

接下来,我们可以编译并训练我们的第一个 TensorFlow 模型。借助 PyCharm 的代码补全功能和文档访问功能,您可以立即获得构建这些简单代码块的建议。

对于 TensorFlow 的首次尝试,只需在 IDE 中按几次 Tab 键,我们就能快速构建一个可运行的模型。我们使用推荐的标准优化器和损失函数,并追踪准确率。您可以通过调整层的数量或类型以及其他参数来构建多个模型进行对比。

$ python
model1.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)
model1.fit(x_train, y_train, epochs=10)
model2.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)
model2.fit(x_train, y_train, epochs=15)

评估并对比 TensorFlow 模型的性能

$ python
loss1, accuracy1 = model1.evaluate(x_test, y_test)
print(f'Accuracy of model1: {accuracy1:.2f}')
loss2, accuracy2 = model2.evaluate(x_test, y_test)
print(f'Accuracy of model2: {accuracy2:.2f}')

一旦模型训练完成(随着每个单元格的运行,您可以直观地看到 epoch 的进度),我们就可以立即评估模型的性能。

在我的实验中,model1 的准确率约为 0.88,虽然 model2 的准确率略高,但它的训练时间增加了 50%。这就是您需要权衡的地方:微小的准确率提升是否值得投入额外的计算资源和模型复杂度?

我们可以通过生成新预测数据集的 DataFrame 实例,进一步深入分析模型运行结果。在这里,我们还可以利用 describe 等内置函数快速获得初步的统计印象:

[LOADING...]

$ python
predictions = model1.predict(x_test)
import pandas as pd
df_pred = pd.DataFrame(predictions, columns=class_names)
df_pred.describe()

然而,最有用的统计数据是将模型的预测与数据集中真实的“地面实况(ground truth)”标签进行比较。我们还可以按类别细分:

$ python
y_pred = model1.predict(x_test).argmax(axis=1)
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(8,6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()
print('Classification report:')
print(classification_report(y_test, y_pred, target_names=class_names))

从中我们可以注意到,不同类别的服装准确率差异很大。一种可能的解释是,裤子与 T 恤或衬衫等服装类型截然不同,而后者更容易被混淆。

当然,作为人类,我们通过观察图像就能发现这种细微差别,但模型只能接触到像素值的矩阵。不过,数据似乎确实印证了我们的直觉。我们可以进一步构建更全面的可视化来验证这一假设。

[LOADING...]

$ python
import numpy as np
import matplotlib.pyplot as plt
# 挑选 8 个错误示例
y_pred = predictions.argmax(axis=1)
wrong_idx = np.where(y_pred != y_test)[0][:8]  # 前 8 个错误
n = len(wrong_idx)
fig, axes = plt.subplots(n, 2, figsize=(10, 2.2 * n), constrained_layout=True)
for row, idx in enumerate(wrong_idx):
    p = predictions[idx]
    pred = int(np.argmax(p))
    true = int(y_test[idx])
    axes[row, 0].imshow(x_test[idx], cmap="gray")
    axes[row, 0].axis("off")
    axes[row, 0].set_title(
        f"WRONG  P:{class_names[pred]} ({p[pred]:.2f})  T:{class_names[true]}",
        color="red",
        fontsize=10
    )
    bars = axes[row, 1].bar(range(len(class_names)), p, color="lightgray")
    bars[pred].set_color("red")
    axes[row, 1].set_ylim(0, 1)
    axes[row, 1].set_xticks(range(len(class_names)))
    axes[row, 1].set_xticklabels(class_names, rotation=90, fontsize=8)
    axes[row, 1].set_ylabel("conf", fontsize=9)
plt.show()

该表格生成了一个视图,让我们能够探索模型在预测时的置信度:通过观察每个类别的权重,我们可以看出模型在哪种情况下存在疑虑(即多个类别权重较高),而在哪种情况下非常确定(只有一个猜测)。这些示例进一步证实了我们的直觉:模型更容易混淆不同类型的上衣。

结论

大功告成!我们已经成功设置并训练了第一个模型,并从数据和模型结果中获得了一些数据科学洞察。此时使用 PyCharm 的部分功能可以加速实验过程,它能提供文档访问权限,并直接在单元格中应用代码补全。我们甚至可以使用 AI Assistant 来帮助生成进一步评估 TensorFlow 模型性能和分析结果所需的图表。

您可以 亲自尝试这个 Notebook,或者更好的是,尝试使用这些工具从零开始生成它,以获得更具实践性的学习体验。

后续步骤

这个 Notebook 是一个极简的教学起点。以下是一些后续尝试的实用建议:

  • 将基础的 Dense 层替换为小型 CNN(Conv2D → MaxPooling → Dense)。
  • 添加 Dropout 或 Batch Normalization 以减少过拟合。
  • 应用数据增强(随机平移/旋转)以提高泛化能力。
  • 使用 EarlyStoppingModelCheckpoint 等回调函数,以提高训练效率并保留最佳权重。
  • 导出 SavedModel 用于服务器,或转换为 TensorFlow Lite 用于边缘设备(Raspberry Pi、微控制器)。

常见问题解答

我应该在什么时候使用 TensorFlow?

TensorFlow 最适合用于构建需要扩展、进入生产环境或在不同环境(云、移动端、边缘设备)中运行的机器学习或深度学习模型。

TensorFlow 特别适用于大规模模型和神经网络,包括那些需要强大部署支持(TensorFlow Serving, TensorFlow Lite)的场景。对于研究原型,TensorFlow 是可行的,但使用轻量级框架进行实验更为常见。

TensorFlow 可以在 GPU 上运行吗?

是的,TensorFlow 可以运行在 GPU 和 TPU 上。此外,使用 GPU 可以显著加快训练速度,特别是对于具有大型数据集的深度学习模型。最棒的是,如果配置得当,TensorFlow 会自动使用可用的 GPU。

TensorFlow 中的“损失(Loss)”是什么?

损失(也称为损失函数)衡量模型预测值与实际目标值之间的差距。TensorFlow 中的损失是一个代表预测值与目标值之间距离的数值。常见示例包括:

  • MSE(均方误差),常用于回归任务。
  • 交叉熵损失(Cross-entropy loss),常用于分类任务。

我应该使用多少个 epoch?

没有固定的 epoch 数,因为它取决于您的数据集和模型。典型的方法包括:

  • 从保守的数值开始(10-50 个 epoch)。
  • 监控验证集的损失/准确率,并根据观察到的结果进行调整。
  • 使用提前停止(Early Stopping)在性能提升减弱时停止训练。

一个 epoch 是对训练数据的完整遍历。遍历次数太少会导致欠拟合,而次数太多则会导致过拟合。最佳点在于模型对未见数据具有最佳泛化能力的时候。

关于作者

[LOADING...]

Iulia Feroli

Iulia 的使命是让科技变得令人兴奋、易于理解,并让新一代人能够轻松接触。

她拥有数据科学、人工智能、云架构和开源领域的背景,在架起技术深度与易用性之间的桥梁方面,她有着独特的见解。

她正在打造自己的品牌 Back To Engineering,通过该平台,她为技术爱好者、工程师和创作者建立了一个社区。从关于从零构建机器人的 YouTube 视频,到关于真实、扎实 AI 的会议演讲或主题演讲,再到技术博客和教程——Iulia 向世界分享着她的观点:如何将复杂的概念转化为开发者每天都能使用的工具。