import torch
import torch.nn as nn
import torch.nn.functional as F


class ImageEmbeddingCNN(nn.Module):
    """Convolutional encoder mapping channels-last RGB images to embeddings.

    The network is four ``3x3`` convolutions (3 -> 32 -> 64 -> 128 -> 256
    channels), each followed by ReLU and a ``2x2`` max-pool, then two fully
    connected layers (-> 128 -> 32) and a final linear projection to
    ``embed_size``.

    Args:
        embed_size: Length of each output embedding vector.
        K: Size of the second output dimension. The per-image embeddings are
            regrouped to ``(-1, K, embed_size)``, so the number of images
            passed to ``forward`` must be a multiple of ``K``.
            NOTE(review): presumably the number of images per sample --
            confirm against callers.
        image_size: Height and width of the (square) input images. Defaults
            to 96, the size the network was originally hard-wired for.

    Raises:
        ValueError: If ``image_size`` is smaller than 16, which would leave
            no spatial features after the four pooling stages.
    """

    def __init__(self, embed_size: int = 128, K: int = 10,
                 image_size: int = 96):
        super().__init__()
        if image_size < 16:
            raise ValueError(
                "image_size must be at least 16, got {}".format(image_size)
            )

        # kernel_size=3 with padding=1 keeps the spatial size unchanged, so
        # only the pooling layers shrink the feature maps.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        # Four floor-halvings of the side length are equivalent to a single
        # floor division by 16 (96 -> 6 for the default size).
        pooled_side = image_size // 16
        self.fc1 = nn.Linear(256 * pooled_side * pooled_side, 128)
        self.fc2 = nn.Linear(128, 32)

        self.embedding = nn.Linear(32, embed_size)
        self.embed_size = embed_size
        self.K = K
        self.image_size = image_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a batch of channels-last images.

        Args:
            x: Tensor that can be reshaped to
                ``(-1, image_size, image_size, 3)``, i.e. images stored with
                the channel axis last. Leading dimensions are flattened into
                a single image axis.

        Returns:
            Tensor of shape ``(num_images // K, K, embed_size)``.

        Raises:
            RuntimeError: If ``x`` cannot be reshaped to the expected image
                shape, or if the number of images is not a multiple of ``K``.
        """
        # reshape (rather than view) also accepts non-contiguous inputs such
        # as slices; for contiguous inputs the result is identical.
        x = x.reshape(-1, self.image_size, self.image_size, 3)
        # Conv2d expects channels-first (N, C, H, W).
        x = x.permute(0, 3, 1, 2)

        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.pool(F.relu(conv(x)))

        # Collapse (C, H, W) into one feature axis per image.
        x = torch.flatten(x, start_dim=1)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))

        embedding = self.embedding(x)

        # Regroup the per-image embeddings into groups of K.
        return embedding.view(-1, self.K, self.embed_size)


def main():
    """Run a dummy batch through the model and print the output shape."""
    # The model's forward() expects channels-last images (N, 96, 96, 3) and
    # performs the permute to channels-first itself. Permuting here as well
    # hands it a non-contiguous (N, 3, 96, 96) tensor that cannot be viewed
    # as (-1, 96, 96, 3), so the input is passed through unchanged.
    x0 = torch.zeros([10, 96, 96, 3])
    model = ImageEmbeddingCNN()
    print(model(x0).shape)


# Run the smoke test only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
