import torch.nn as nn


class LeNet(nn.Module):
    """Small all-convolutional LeNet-style classifier with sigmoid activations.

    The feature extractor is four 5x5 convolutions (12 output channels each,
    ``padding=2``) with strides 2, 2, 1, 1, each followed by a sigmoid. The
    flattened features then go through a single linear layer.

    Each stride-2 convolution maps a spatial side ``S`` to ``ceil(S / 2)``.
    The stride-1 convolutions keep the size. An ``H x W`` input therefore
    yields a ``12 x ceil(H / 4) x ceil(W / 4)`` feature map. For the default
    3x32x32 input this is 12x8x8, i.e. 768 flattened features.

    NOTE(review): sigmoid rather than ReLU is presumably deliberate (e.g. to
    keep the network smooth / twice differentiable); confirm before changing.

    Args:
        in_channels: Number of channels in the input image (default 3).
        hidden: Flattened feature size fed to the linear layer. It must equal
            ``12 * ceil(H / 4) * ceil(W / 4)`` for ``H x W`` inputs (default
            768, which corresponds to 32x32 inputs).
        num_classes: Number of output logits (default 10).
    """

    def __init__(self, in_channels=3, hidden=768, num_classes=10):
        super().__init__()
        act = nn.Sigmoid
        # Layer order/indices are kept stable so state_dict keys
        # (body.0.weight, body.2.weight, ..., fc.0.weight) stay compatible.
        self.body = nn.Sequential(
            nn.Conv2d(in_channels, 12, kernel_size=5, padding=5 // 2, stride=2),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=2),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=1),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=1),
            act(),
        )
        self.fc = nn.Sequential(nn.Linear(hidden, num_classes))

    def forward(self, x):
        """Compute class logits.

        Args:
            x: Batched images of shape ``(N, in_channels, H, W)``. The spatial
                size must match ``hidden`` as described on the class.

        Returns:
            Tensor of shape ``(N, num_classes)`` holding unnormalized logits.
        """
        out = self.body(x)
        # flatten(1) (unlike view) also works when the feature map is not
        # NCHW-contiguous, e.g. for channels_last inputs; the element order
        # is the same logical (C, H, W) order either way.
        out = out.flatten(1)
        out = self.fc(out)
        return out