#!/usr/bin/python3

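# The "Crypto" package may need to be installed first: "pip3 install pycryptodome"
# (a maintained drop-in replacement for the abandoned "pycrypto", which also
# provides "Crypto"; "pip3 install pycrypto" works only on old Python versions).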

import Crypto.Random
import random
from Crypto.Cipher import AES


def keygen():
    """Return a fresh random 256-bit (32-byte) AES key as bytes."""
    key_len = 256 // 8  # AES-256 key size in bytes
    rng = Crypto.Random.new()
    return rng.read(key_len)

# One all-zero AES block. encrypt() prepends it to every plaintext and
# decrypt() checks for it after decryption to detect a wrong key.
zero = b"\x00" * AES.block_size
    
def encrypt(k, msg):
    """Encrypt msg under key k with AES-CBC and a fresh random IV.

    The all-zero block `zero` is prepended to the plaintext so that
    decrypt() can recognise a wrong key. No padding is applied, so
    len(msg) must be a multiple of AES.block_size (the cipher library
    presumably rejects anything else).

    Returns:
        iv || CBC-ciphertext(zero || msg), as bytes.
    """
    assert isinstance(k, bytes), "Key must be a byte sequence"
    assert isinstance(msg, bytes), "Msg must be a byte sequence"
    nonce = Crypto.Random.new().read(AES.block_size)
    plaintext = zero + msg  # leading marker block checked by decrypt()
    body = AES.new(k, AES.MODE_CBC, nonce).encrypt(plaintext)
    return nonce + body


def decrypt(k,c):
    """Decrypt a ciphertext produced by encrypt().

    The expected layout is iv || CBC-ciphertext(zero || msg). After
    decryption the leading block must equal `zero`. If it does not, the key
    is taken to be wrong and None is returned instead of garbage plaintext.
    A wrong key passes this check only with negligible probability.

    Args:
        k: AES key. The type check also admits str, though the cipher
           library may reject it.
        c: ciphertext. The type check also admits str.

    Returns:
        The plaintext bytes. Returns None if the key is wrong, or if c is
        too short or not block-aligned to be a valid encrypt() output.
    """
    assert isinstance(k,(str,bytes)), "Key must be a string/byte sequence"
    assert isinstance(c,(str,bytes)), "Ciphertext must be a string/byte sequence"
    bs = AES.block_size
    # A valid ciphertext holds at least the IV plus the zero marker block and
    # is block-aligned (encrypt() never pads). Anything else cannot decrypt
    # cleanly, so report it like a wrong key rather than letting the cipher
    # raise.
    if len(c) < 2 * bs or len(c) % bs:
        return None
    iv = c[:bs]
    cipher = AES.new(k,AES.MODE_CBC,iv)
    msg = cipher.decrypt(c[bs:])
    if msg[:bs] != zero:
        return None
    return msg[bs:]

# Test: a ciphertext decrypts back to the original message under the same key
k, m = keygen(), b"hello hello xxxx"
c = encrypt(k, m)
assert m == decrypt(k, c)

# Test: decrypting under an unrelated key is rejected and yields None
k, k2 = keygen(), keygen()
m = b"hello hello xxxx"
c = encrypt(k, m)
assert decrypt(k2, c) is None


def make_gate(k_left_0,k_left_1,k_right_0,k_right_1,m00,m01,m10,m11):
    """Create a "garbled" gate.

    k_left_0, k_left_1 are the two keys for the left input, and
    k_right_0, k_right_1 are the two keys for the right input.
    m00, m01, m10, m11 are the four possible outputs of the gate (you can
    assume that they are 16-byte values; any multiple of AES.block_size
    works, because encrypt() does not pad).

    Each output mij is encrypted first under k_right_j and then under
    k_left_i, so only the holder of both matching keys can recover it.

    Returns:
        A tuple of the four ciphertexts in random order.
    """
    rows = [
        encrypt(k_l, encrypt(k_r, out))
        for k_l, k_r, out in (
            (k_left_0, k_right_0, m00),
            (k_left_0, k_right_1, m01),
            (k_left_1, k_right_0, m10),
            (k_left_1, k_right_1, m11),
        )
    ]
    # Hide which row belongs to which (i, j). A predictable permutation would
    # leak the evaluator's input bits, so use the OS-backed CSPRNG rather
    # than the default (Mersenne Twister) generator.
    random.SystemRandom().shuffle(rows)
    return tuple(rows)
    
def eval_gate(k_left, k_right, gate):
    """Evaluate a garbled gate.

    If gate was created by make_gate, k_left == k_left_i and
    k_right == k_right_j, then this function returns mij.

    Every row is trial-decrypted: first the outer layer with k_left, then
    the inner layer with k_right. decrypt() returns None when a key does not
    fit, so rows encrypted for other key pairs are skipped.

    Raises:
        ValueError: if no row decrypts under both keys (i.e. the keys do not
            belong to this gate).
    """
    for row in gate:
        inner = decrypt(k_left, row)
        if inner is None:
            continue  # row was encrypted under the other left key
        # Two rows share this left key; only one matches k_right.
        out = decrypt(k_right, inner)
        if out is not None:
            return out
    raise ValueError("no row of the garbled gate decrypts under the given keys")

    
# Test: make_gate -- two random keys per input wire, four 16-byte outputs
k_left_0, k_left_1 = keygen(), keygen()
k_right_0, k_right_1 = keygen(), keygen()
m00, m01, m10, m11 = (
    b"This is m00.....",
    b"This is m01.....",
    b"This is m10.....",
    b"This is m11.....",
)
gate = make_gate(k_left_0, k_left_1, k_right_0, k_right_1,
                 m00, m01, m10, m11)

# Test: eval_gate must return m_ij for each of the four key choices (i, j)
k_left, k_right = k_left_0, k_right_0
m = eval_gate(k_left, k_right, gate)
assert m00 == m

k_left, k_right = k_left_0, k_right_1
m = eval_gate(k_left, k_right, gate)
assert m01 == m

k_left, k_right = k_left_1, k_right_0
m = eval_gate(k_left, k_right, gate)
assert m10 == m

k_left, k_right = k_left_1, k_right_1
m = eval_gate(k_left, k_right, gate)
assert m11 == m

