Table of Contents
  1. Dataset class
  2. DataLoader
    2.1. Training workflow
Dataset Processing and Loading in PyTorch

Dataset class

A dataset is, at its core, nothing more than a collection of {x: y} sample pairs; all this class has to do is describe that collection.

For an image classification task: image + class label

For an object detection task: image + bounding boxes and class labels

For a super-resolution task: low-resolution image + high-resolution image

For a text classification task: text + class label

Below is the relevant excerpt from the official documentation:


class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

In the official code above, __getitem__ and __len__ are the two methods that every subclass must override.

The reason is simple: this official class defines the contract your dataset has to follow. __getitem__ fetches one sample pair {x: y}, which is exactly what the model receives, and __len__ returns the length of the dataset.
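
As a minimal sketch (not from the PyTorch documentation), the toy subclass below stores two parallel Python lists; indexing the dataset with ds[i] goes through __getitem__ and len(ds) goes through __len__:

from torch.utils.data import Dataset

class PairDataset(Dataset):
    """A toy dataset over two parallel Python lists."""
    def __init__(self, xs, ys):
        assert len(xs) == len(ys)
        self.xs, self.ys = xs, ys

    def __getitem__(self, index):
        return self.xs[index], self.ys[index]   # one {x: y} sample pair

    def __len__(self):
        return len(self.xs)

ds = PairDataset([0.0, 1.0, 2.0], [0, 1, 1])
print(len(ds))   # -> 3, via __len__
print(ds[1])     # -> (1.0, 1), via __getitem__

A more realistic dataset, which reads image paths and labels from a datalist.txt file, can look like this: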

from PIL import Image
from torch.utils.data import Dataset


class MyDataSet(Dataset):
    def __init__(self, dataset_type, transform=None, update_dataset=False):
        """
        dataset_type: ['train', 'test']
        """

        dataset_path = '/home/muzhan/projects/dataset/'

        if update_dataset:
            make_txt_file(dataset_path)  # regenerate the data list (user-defined helper)

        self.transform = transform
        self.sample_list = list()
        self.dataset_type = dataset_type
        f = open(dataset_path + self.dataset_type + '/datalist.txt')
        lines = f.readlines()
        for line in lines:
            self.sample_list.append(line.strip())
        f.close()

    def __getitem__(self, index):
        item = self.sample_list[index]
        # img = cv2.imread(item.split(' _')[0])
        img = Image.open(item.split(' _')[0])
        if self.transform is not None:
            img = self.transform(img)
        label = int(item.split(' _')[-1])  # the label is also stored in datalist.txt, separated by ' _'
        return img, label

    def __len__(self):
        return len(self.sample_list)
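
__getitem__ can also do heavier work than just opening an image. The method below, excerpted from a patch-based image-classification dataset, optionally augments each image with rotations and horizontal flips, cuts it into patches with a PatchExtractor, and returns all normalised patches stacked in one tensor:
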
def __getitem__(self, index):

    with Image.open(self.names[index]) as img:
        bins = 8 if self.augment else 1
        extractor = PatchExtractor(img=img, patch_size=PATCH_SIZE, stride=self.stride)
        b = torch.zeros((bins, extractor.shape()[0] * extractor.shape()[1], 3, PATCH_SIZE, PATCH_SIZE))
        for k in range(bins):

            # rotate the image by 90/180/270 degrees for k = 1, 2, 3 (and 5, 6, 7)
            if k % 4 != 0:
                img = img.rotate((k % 4) * 90)

            # horizontally flip the image for the second group of four
            if k // 4 != 0:
                img = img.transpose(Image.FLIP_LEFT_RIGHT)

            extractor = PatchExtractor(img=img, patch_size=PATCH_SIZE, stride=self.stride)
            patches = extractor.extract_patches()

            for i in range(len(patches)):
                patch_tensor = transforms.ToTensor()(patches[i])
                b[k, i] = transforms.Normalize((0.486, 0.459, 0.408), (0.229, 0.224, 0.225))(patch_tensor)

    # flatten the (bins, n_patches) dimensions into one stack of patches
    b = b.view((-1, 3, PATCH_SIZE, PATCH_SIZE))
    label = self.labels[self.names[index]]

    return b, label
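
Because every sample returned here is already a stack of bins * n_patches patches rather than a single image, the matching DataLoader is created with batch_size=1 and the extra batch dimension is removed with squeeze(), as in the ctest method shown later.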

DataLoader

The Dataset is only responsible for abstracting the data. As mentioned earlier, when training a neural network it is best to operate on a batch of data at a time, and we also need shuffling, parallel loading, and so on. PyTorch provides DataLoader for exactly this purpose:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           num_workers=0, collate_fn=default_collate, pin_memory=False,
           drop_last=False)

dataset: the dataset to load from (a Dataset object)
batch_size: the batch size
shuffle: whether to shuffle the data
sampler: custom sampling strategy, covered in detail later
num_workers: number of worker processes used for loading; 0 means everything is loaded in the main process
collate_fn: how several samples are merged into one batch; the default collation is usually fine
pin_memory: whether to keep the data in pinned (page-locked) memory; transfers from pinned memory to the GPU are faster
drop_last: the dataset size may not be divisible by batch_size; with drop_last=True the final incomplete batch is discarded

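Putting the pieces together, here is a minimal sketch of feeding the MyDataSet defined above into a DataLoader; the transform pipeline and the concrete parameter values are assumptions for illustration only:

from torch.utils.data import DataLoader
from torchvision import transforms

# assumed preprocessing; adapt to your own images
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

train_set = MyDataSet('train', transform=transform)
train_loader = DataLoader(train_set,
                          batch_size=16,
                          shuffle=True,      # reshuffle the samples every epoch
                          num_workers=4,     # load batches with 4 worker processes
                          drop_last=True)    # drop the final incomplete batch

for images, labels in train_loader:
    # images: float tensor of shape (16, 3, 224, 224); labels: int tensor of shape (16,)
    pass

Each iteration of the loop yields one ready-made batch, which is exactly what the training loop in the next section consumes.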

Training workflow

# Define a base class, BaseModel, that handles loading and saving weights
class BaseModel:
    def __init__(self, args, net, weights):
        self.args = args
        self.weights = weights              # path where the weights are stored
        self.net = net.to(args.device)

    def load(self):
        try:
            if os.path.exists(self.weights):
                print('Loading classifier...')
                self.net.load_state_dict(torch.load(self.weights))
        except Exception:
            print('Failed to load pre-trained network')

    def save(self):
        print('Saving model to "{}"'.format(self.weights))
        torch.save(self.net.state_dict(), self.weights)
# Subclass BaseModel:
class Model(BaseModel):
    def __init__(self, args, net):
        super(Model, self).__init__(args, net, args.net_name)

    def train(self):
        # train_loss and validate_loss are initialised according to the number of epochs
        train_loss = np.zeros(self.args.epochs)
        validate_loss = np.zeros(self.args.epochs)
        best_acc = 0.0

        # declare the Dataset and DataLoader (train_loader) here

        # define the criterion, the optimizer and the learning-rate scheduler here

        for epoch in range(self.args.epochs):
            self.net.train()
            scheduler.step()        # advance the learning-rate schedule
            correct = 0             # running counts for this epoch
            total = 0
            for index, (images, labels) in enumerate(train_loader):
                # enumerate() yields the batch index and the batch itself; the index used
                # inside the Dataset addresses single samples, so with batch_size=16 the
                # images tensor has shape 16*3*224*224
                # move the batch onto the GPU
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                # A batch holds at least one image, and the batch loss is the mean of the
                # per-image losses, so gradients accumulate over the batch; they have to be
                # cleared before the gradients of the next batch are computed
                optimizer.zero_grad()              # reset the gradients of the previous batch
                output = self.net(images)          # forward pass through the network
                loss = criterion(output, labels)   # compute the loss with the criterion
                loss.backward()
                optimizer.step()                   # update the parameters once per mini-batch

                # dim=1: return, for every row, the largest element together with its
                # column index (dim=0 would do the same per column);
                # the result is one tensor of values and one tensor of indices
                _, predicted = torch.max(output.data, 1)   # .data gives a tensor sharing output's storage
                correct += torch.sum(predicted == labels).item()
                # train_loss[epoch] accumulates all losses of this epoch; each batch loss
                # is itself the average over the images in the batch
                train_loss[epoch] += loss.item() * self.args.batch_size
                total += len(images)

                if index > 0 and index % self.args.log_interval == 0:   # print a log line every log_interval iterations
                    print('Epoch: {}/{} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, Accuracy: {:.2f}%'.format(
                        epoch + 1,
                        self.args.epochs,
                        index * len(images),          # samples seen so far in this epoch
                        len(train_loader.dataset),    # calls the Dataset's __len__
                        100. * index / len(train_loader),
                        loss.item(),
                        100 * correct / total
                    ))
            # after the epoch, divide the accumulated loss by the number of samples
            train_loss[epoch] = train_loss[epoch] / total
            # measure the accuracy on the validation set
            validate_loss[epoch], validate_acc = self.validate()
            if validate_acc > best_acc:
                best_acc = validate_acc
                self.save()
        # all epochs finished, report the result
        print('\nEnd of training, best class accuracy: {}'.format(best_acc))

        plt.plot(train_loss, 'b-', label='train_loss')
        plt.plot(validate_loss, 'r--', label='validate_loss')

        plt.xlabel('epochs')
        plt.title("Loss curve")
        plt.legend()
        plt.savefig(self.args.net_name + '.jpg')


    def validate(self, verbose=False):
        self.net.eval()

        loss = 0
        correct = 0
        dataset = TrainDataset(path=self.args.dataset_path + VALIDATION_PATH,
                               stride=self.args.validate_stride)

        validate_loader = DataLoader(
            dataset=dataset,
            batch_size=self.args.batch_size,
            shuffle=False,
            num_workers=0
        )

        criterion = nn.CrossEntropyLoss(reduction='sum')   # sum the per-sample losses (was size_average=False)

        print('\nEvaluating classifier')

        with torch.no_grad():   # no gradients are recorded, which saves GPU memory
            for images, labels in validate_loader:
                images = images.to(self.args.device)
                output = self.net(images)

                # the criterion returns a single-element tensor; .item() extracts its value
                loss += criterion(output.cpu(), labels).item()
                _, predicted = torch.max(output.cpu(), 1)
                correct += (predicted == labels).sum().item()

        loss /= len(dataset)
        acc = correct / len(dataset)

        print('Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
            loss,
            correct,
            len(dataset),
            100. * correct / len(dataset)
        ))

        return loss, acc

    # Cross validation
    def ctest(self, path, augment=False):
        self.load()
        self.net.eval()

        dataset = cTestDataset(path=path, stride=self.args.test_stride, augment=augment)
        data_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)

        stime = datetime.datetime.now()
        print('\t sum\t\t max\t\t maj')

        labels = []
        maj_pred = []
        prop_pred = np.zeros((len(dataset), len(LABELS)), dtype=float)
        maj_correct = 0
        with torch.no_grad():
            for index, (image, label) in enumerate(data_loader):
                image = image.squeeze()
                image = image.to(self.args.device)

                output = self.net(image)
                _, predicted = torch.max(output.data, 1)

                # The following measures are prioritised based on [invasive, insitu, benign, normal].
                # The original labels are [normal, benign, insitu, invasive], so we reverse the order using [::-1].
                # The output has shape 12x4; note that the class used here differs from the one used at test time.
                # sum_prob: sum of probabilities along the y axis: (1, 4), reverse, and take the index of the largest value
                # max_prob: max of probabilities along the y axis: (1, 4), reverse, and take the index of the largest value
                # maj_prob: majority voting: build one-hot vectors of the predictions: (12, 4), sum along the y axis: (1, 4), reverse, and take the index of the largest value
                sum_prob = 3 - np.argmax(np.sum(np.exp(output.data.cpu().numpy()), axis=0)[::-1])
                max_prob = 3 - np.argmax(np.max(np.exp(output.data.cpu().numpy()), axis=0)[::-1])
                maj_prob = 3 - np.argmax(np.sum(np.eye(4)[np.array(predicted).reshape(-1)], axis=0)[::-1])

                labels.append(label.item())

                maj_pred.append(maj_prob)
                prop_pred[index, :] = np.sum(np.eye(4)[np.array(predicted).reshape(-1)], axis=0) / len(image)

                print('{}) \t {} \t {} \t {} \t {}'.format(
                    str(index + 1).rjust(2, '0'),
                    LABELS[sum_prob].ljust(8),
                    LABELS[max_prob].ljust(8),
                    LABELS[maj_prob].ljust(8),
                    str(label.data)))

                maj_correct += torch.sum(maj_prob == label).item()

        maj_acc = maj_correct / len(dataset)
        print('Maj_image accuracy: {}/{} ({:.2f}%)'.format(
            maj_correct,
            len(data_loader.dataset),
            100. * maj_acc
        ))

        print('\nInference time: {}\n'.format(datetime.datetime.now() - stime))

        return prop_pred, labels
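
To tie the pieces together, here is a rough sketch of how these classes might be driven from a training script; args (an argparse-style namespace with fields such as device, epochs, batch_size, log_interval, net_name and dataset_path), build_args and Net are assumptions, not part of the original code:

args = build_args()      # hypothetical helper that parses the command-line arguments
net = Net()              # any classifier network (hypothetical)
model = Model(args, net)

model.load()             # optionally resume from the weights stored under args.net_name
model.train()            # train, validate after every epoch, and save the best checkpoint
prop_pred, labels = model.ctest(path=args.dataset_path + 'test/', augment=True)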
Author: HazardFY
Link: http://hazardfy.github.io/2019/11/16/PYtorch中数据集的处理和载入/
Copyright: Unless otherwise stated, all posts on this blog are licensed under CC BY-NC-SA 4.0. Please credit HazardFY's BLOG when reposting.