前言

什么是Kaggle?

Kaggle是全球领先的数据科学竞赛平台,提供丰富的数据集和强大的在线笔记本环境,让用户能够在云端进行数据分析、模型训练和分享。Kaggle不仅支持公开数据集,也允许用户上传自定义数据,方便进行个性化实验。

PyTorch简介

PyTorch是当前非常流行的深度学习框架之一,特别适合研究和快速原型开发。它提供了灵活的动态图机制,并拥有丰富的工具和库支持,如自动求导、神经网络模块和数据加载工具。

torchvision.datasets详解

torchvision.datasets是PyTorch生态中非常实用的模块,内置了多个常见视觉数据集的加载接口(如MNIST、CIFAR10、FashionMNIST等)。其主要功能包括:

  • 自动下载数据
  • 读取数据文件和标签
  • 支持预处理和变换(transforms)
  • 方便的训练/测试数据区分

使用时,有几个关键参数需要留意:

  • root:数据集根目录,存放数据的地方
  • train:是否加载训练集(True)或测试集(False)
  • download:数据不存在时是否自动下载
  • transform:对加载的图像进行预处理,如转Tensor、归一化
  • target_transform:对标签的处理

在Kaggle这种环境下,常用的做法是先将数据上传到/kaggle/input目录,然后复制到工作目录,再通过datasets模块的接口加载。

shutil模块简介

shutil是Python的高级文件操作模块,常用来进行文件和文件夹的复制、移动、删除等操作,较os模块更简洁且功能集中。在处理数据搬运时非常方便,例如:

  • shutil.copytree():递归复制文件夹及内容
  • shutil.rmtree():递归删除文件夹及所有内容

具体流程

1. 上传数据集

将数据集文件夹通过Kaggle界面上传,假设命名为mnist-data。上传后文件会出现在路径/kaggle/input/mnist-data下。

2. 创建目录并迁移数据

有两种常见方式将数据从输入目录复制到工作目录:

方式1:使用os结合Linux命令cp

1
2
3
4
5
6
7
8
9
import os 

output_path = "/kaggle/working/data/MNIST/raw"

# 创建输出目录(递归创建,避免路径不存在报错)
os.makedirs(output_path, exist_ok=True)

# 使用Linux命令复制文件
!cp -r /kaggle/input/mnist-data/* /kaggle/working/data/MNIST/raw

方式2:Python内置模块shutil.copytree()

1
2
3
4
5
6
7
import shutil

input_path = "/kaggle/input/mnist-data"
output_path = "/kaggle/working/data/MNIST"

# 复制整个文件夹及其内容(copytree会自动创建目标路径)
shutil.copytree(input_path, output_path)

🔔注意事项:

  • shutil.copytree()不能复制到已存在的目录,会报错。如果目录存在,建议先删除或用shutil.copytree(..., dirs_exist_ok=True)(Python 3.8+支持)。
  • os.makedirs()确保目标文件夹存在。

3.迁移后数据目录展示

4. PyTorch导入数据并输出数据维度

通过torchvision.datasets加载数据:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
from torchvision import datasets
from torchvision.transforms import ToTensor

train_data = datasets.MNIST(
    root="./data",         # 指定数据存放根目录
    train=True,            # 加载训练集
    download=False,        # 数据已上传,本地已有就不用自动下载
    transform=ToTensor(),  # 转换为Tensor方便PyTorch使用
    target_transform=None
)

test_data = datasets.MNIST(
    root="./data",
    train=False,
    download=False,
    transform=ToTensor(),
    target_transform=None
)

print(f"训练集样本数量: {len(train_data)}")
print(f"测试集样本数量: {len(test_data)}")

这里datasets.MNIST实际上会在root指定路径下尝试读取MNIST文件夹,并在其中查找数据文件(尤其是raw文件夹)。如果数据已上传并按照规范摆放,即可正常加载。


其余相关命令

获取默认工作目录

1
2
import os
print(os.getcwd())

一般返回:

1
/kaggle/working

更改当前工作目录

1
2
3
4
5
import os

input_path = "/kaggle/input"
os.chdir(input_path)
print(os.getcwd())  # 验证目录变更

创建目录

1
2
3
4
import os

output_path = "/kaggle/working/data/MNIST/raw"
os.makedirs(output_path, exist_ok=True)

使用shutil删除目录

1
2
3
4
import shutil

directory_path = "/kaggle/working/data/MNIST/"
shutil.rmtree(directory_path)  # 小心使用,删除目录及所有内容

使用shutil复制文件夹

1
2
3
4
5
import shutil

input_path = "/kaggle/input/mnist-data"
output_path = "/kaggle/working/data/MNIST"
shutil.copytree(input_path, output_path, dirs_exist_ok=True)  # Python 3.8+支持覆盖复制

使用Linux命令cp复制文件

1
!cp -r /kaggle/input/mnist-data/* /kaggle/working/data/MNIST/raw

总结

  • 先上传数据集到/kaggle/input目录;
  • 使用Python文件操作、shutil或Linux命令将数据复制到/kaggle/working工作目录;
  • 使用torchvision.datasets加载数据,设置root路径及transform,加载自己上传的数据;
  • 注意shutil.copytree()的新参数dirs_exist_ok可以避免已有目录错误;
  • PyTorch的数据加载不仅限于MNIST,熟悉datasets使用可以方便加载多种常用视觉数据集。