# -*- coding: utf-8 -*-
import torch.nn as nn
import torch.nn.functional as F
class MnistCNN(nn.Module):
    """Small convolutional classifier that outputs log-probabilities for 10 classes.

    Two unpadded 3x3 convolutions are followed by a 6x6 max-pool and two
    linear layers. ``fc3`` expects 1024 flattened features, which is what a
    1x28x28 input yields (28 -> 26 -> 24 -> 4 per side, times 64 channels).
    """

    def __init__(self):
        super(MnistCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.fc3 = nn.Linear(1024, 128)
        self.fc4 = nn.Linear(128, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images.

        Args:
            x: 4-D tensor with one channel; the spatial size must flatten to
                1024 features after the pooling step (e.g. 28x28).

        Returns:
            Tensor of shape ``(batch, 10)`` holding log-softmax scores.
        """
        h = F.relu(self.conv1(x))
        h = F.relu(self.conv2(h))
        h = F.max_pool2d(h, 6)
        # The functional dropout helpers default to training=True, so the
        # module's mode has to be forwarded or dropout stays on after eval().
        h = F.dropout2d(h, p=0.25, training=self.training)
        h = self.fc3(h.view(h.size(0), -1))
        # The fc3 output is 2-D; dropout2d is meant for feature maps, so plain
        # element-wise dropout is the right call here.
        h = F.dropout(h, p=0.5, training=self.training)
        h = self.fc4(h)
        # Normalise over the class dimension explicitly (implicit dim is deprecated).
        return F.log_softmax(h, dim=1)
class CifarCNN(nn.Module):
    """Convolutional classifier that outputs log-probabilities for 10 classes.

    Four padded 3x3 convolutions (each followed by batch normalisation) are
    interleaved with two 4x4 max-pools, then three linear layers. ``fc5``
    expects 512 flattened features, which is what a 3x32x32 input yields
    (32 -> 8 -> 2 per side, times 128 channels).
    """

    def __init__(self):
        super(CifarCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 128, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.fc5 = nn.Linear(512, 256)
        self.fc6 = nn.Linear(256, 256)
        self.fc7 = nn.Linear(256, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images.

        Args:
            x: 4-D tensor with three channels; the spatial size must flatten
                to 512 features after the second pooling step (e.g. 32x32).

        Returns:
            Tensor of shape ``(batch, 10)`` holding log-softmax scores.
        """
        h = F.relu(self.bn1(self.conv1(x)))
        h = F.relu(self.bn2(self.conv2(h)))
        h = F.max_pool2d(h, 4)
        h = F.relu(self.bn3(self.conv3(h)))
        h = F.relu(self.bn4(self.conv4(h)))
        h = F.max_pool2d(h, 4)
        h = F.relu(self.fc5(h.view(h.size(0), -1)))
        h = F.relu(self.fc6(h))
        h = self.fc7(h)
        # Normalise over the class dimension explicitly; relying on the
        # implicit dimension is deprecated and emits a warning.
        return F.log_softmax(h, dim=1)
class Generator(nn.Module):
    """Convolutional encoder-decoder mapping an image batch to a same-channel output.

    Two stride-2 convolutions shrink the feature map and two stride-2
    transposed convolutions grow it back; the final activation is ``tanh``,
    so every output value lies in [-1, 1].
    """

    def __init__(self, in_ch):
        """Create the layers; ``in_ch`` is both the input and output channel count."""
        super(Generator, self).__init__()
        # Down-sampling half.
        self.conv1 = nn.Conv2d(in_ch, 64, kernel_size=4, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        # Up-sampling half.
        self.deconv3 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.deconv4 = nn.ConvTranspose2d(64, in_ch, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        """Run ``x`` through the encoder-decoder and return the tanh-squashed result."""
        out = x
        # Every hidden stage is layer -> batch norm -> leaky ReLU.
        hidden_stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.deconv3, self.bn3),
        )
        for layer, norm in hidden_stages:
            out = F.leaky_relu(norm(layer(out)))
        return F.tanh(self.deconv4(out))
class Discriminator(nn.Module):
    """Convolutional binary scorer that emits one sigmoid value per input image.

    Three unpadded stride-2 3x3 convolutions feed a single linear unit. The
    linear layer's input width is fixed at construction time: 1024 when
    ``in_ch == 1`` and 2304 otherwise (these match 28x28 and 32x32 inputs
    respectively: 256 * 2 * 2 and 256 * 3 * 3).
    """

    def __init__(self, in_ch):
        """Create the layers for inputs with ``in_ch`` channels."""
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(in_ch, 64, kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2)
        self.bn3 = nn.BatchNorm2d(256)
        # The flattened feature count depends on which input size is expected.
        flat_features = 1024 if in_ch == 1 else 2304
        self.fc4 = nn.Linear(flat_features, 1)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of scores in the open interval (0, 1)."""
        feat = F.leaky_relu(self.conv1(x))  # the first stage has no batch norm
        feat = F.leaky_relu(self.bn2(self.conv2(feat)))
        feat = F.leaky_relu(self.bn3(self.conv3(feat)))
        flat = feat.view(feat.size(0), -1)
        return F.sigmoid(self.fc4(flat))
if __name__ == "__main__":
    # Smoke test: push one random batch of ten 3x32x32 inputs through
    # CifarCNN. The result is discarded; this only checks the forward pass runs.
    import torch
    from torch.autograd import Variable

    batch_shape = (10, 3, 32, 32)
    noise = torch.normal(mean=0, std=torch.ones(*batch_shape))
    net = CifarCNN()
    net(Variable(noise))