# CSE-107 / hw7 / hw7.py
import json
import sys, os, itertools
from collections import defaultdict
from timeit import default_timer as timer
sys.path.append(os.path.abspath(os.path.join('..')))
from playcrypt.tools import *
from playcrypt.new_tools import *
from playcrypt.primitives import *

from playcrypt.games.game import Game
from playcrypt.simulator.cca_sim import CCASim

def ADD(a,b):
    """Return the sum ``a + b``."""
    total = a + b
    return total
def MULT(a,b):
    """Return the product ``a * b``."""
    product = a * b
    return product
def INT_DIV(a,N):
    """Return ``(quotient, remainder)`` of floor-dividing ``a`` by ``N``."""
    quotient = a // N
    remainder = a % N
    return quotient, remainder
def MOD(a,N):
    """Return ``a`` reduced modulo ``N`` using Python's ``%`` operator."""
    residue = a % N
    return residue
def EXT_GCD(a,N):
    """Return the extended-GCD result of ``a`` and ``N``.

    Thin wrapper over ``egcd`` (brought in by the playcrypt star imports).
    Presumably the result is ``(g, x, y)`` with ``a*x + N*y == g`` -- TODO
    confirm against playcrypt; ``GCD`` below only relies on index 0 being g.
    """
    result = egcd(a, N)
    return result
def GCD(a,N):
    """Return gcd(a, N), taken as the first component of ``egcd(a, N)``."""
    ext_result = egcd(a, N)
    return ext_result[0]
def MOD_INV(a,N):
    """Return the multiplicative inverse of ``a`` modulo ``N``.

    Delegates to playcrypt's ``modinv``, which signals a missing inverse by
    returning ``None``.

    :raises ValueError: if ``a`` has no inverse modulo ``N``.
    """
    res = modinv(a, N)
    # Identity test for the None sentinel: ``== None`` would go through
    # the result's ``__eq__`` instead of checking for the singleton.
    if res is None:
        raise ValueError("Inverse does not exist.")
    return res
def MOD_EXP(a,n,N):
    """Return ``a`` to the power ``n`` modulo ``N`` via playcrypt's ``exp``."""
    power = exp(a, n, N)
    return power

"""
!!!READ THIS FIRST!!!
Please name your collaborator(s) for this assignment. You can also state that
you have worked completely independently. Please refer to the course website
regarding the collaboration policy.

Your response: Independently


"""


"""
Problem 1

Let K be an RSA key-generation algorithm, returning public key (N, e) and
private key (N, d, p, q).

Let E2 be the encryption algorithm which, on inputs (N, e) and M, returns
(M || R || Z)^e mod N, where || denotes string concatenation, R is a random
string, Z is an all-zero string, and (M || R || Z) is treated as an integer.
In the code below, R and Z are each ell_bytes // 2 bytes long. Note that the
function E below is a helper function to implement E2. We have similarly
provided a helper function D which may help you write D2.
"""

def _rand_prime(lo, hi, e):
    """Return a random ``p`` in [lo, hi] accepted by ``is_prime`` with gcd(e, p-1) == 1."""
    while True:
        p = random.randint(lo, hi)
        # Cheap gcd filter first: candidates with gcd(e, p-1) != 1 can never
        # give an invertible e, so skip the costlier primality test for them.
        if GCD(e, p - 1) != 1:
            continue
        if is_prime(p):
            return p


def K(k):
    """Generate an RSA key pair with public exponent e = 65537.

    p and q are sampled uniformly from [2**(k//2) // 2, 2**(k//2)] until
    ``is_prime`` accepts them and gcd(e, p-1) == gcd(e, q-1) == 1, which makes
    e invertible modulo phi(N) = (p-1)(q-1). q is additionally required to
    differ from p.

    :param k: target modulus size in bits.
    :return: ``((N, e), (N, d, p, q))`` -- public key and private key.
    """
    e = 65537
    top = 2**(k//2)
    p = _rand_prime(top // 2, top, e)
    q = _rand_prime(top // 2, top, e)
    # If p == q then N = p**2 and phi(N) = p*(p-1), not (p-1)**2, so d below
    # would not invert e and decryption would break. Resample until distinct
    # (only ever triggers with noticeable probability for small k).
    while q == p:
        q = _rand_prime(top // 2, top, e)
    N = p*q
    phi_of_N = (p-1) * (q-1)
    d = MOD_INV(e, phi_of_N)
    return (N, e), (N, d, p, q)

def D(sk, c):
    """Raw RSA decryption helper: return the ``k_bytes - 1``-byte plaintext of ``c``.

    Reads the module-level global ``k_bytes``. Returns None when ``c`` is
    None, when its integer value is not a unit modulo N (``c >= N`` or
    ``gcd(c, N) != 1``), or when the recovered value does not fit in
    ``k_bytes - 1`` bytes.
    """
    if c is None:
        return None
    c_int = string_to_int(c)
    N, d, _p, _q = sk
    if c_int >= N or GCD(c_int, N) != 1:
        return None
    width = k_bytes - 1
    m = int_to_string(MOD_EXP(c_int, d, N), width)
    # A result longer than `width` means the plaintext overflowed the
    # expected k_bytes-1 length, which a valid ciphertext never does.
    return m if len(m) == width else None

def E(pk, m):
    """Raw RSA encryption helper: return ``int(m)**e mod N`` as a ``k_bytes``-byte string.

    ``m`` must be exactly ``k_bytes - 1`` bytes long (asserted). Returns None
    when ``int(m) >= N`` or ``gcd(int(m), N) != 1``. Reads the module-level
    global ``k_bytes``.
    """
    assert len(m) == k_bytes - 1
    m_int = string_to_int(m)
    N, e = pk
    if m_int >= N or GCD(m_int, N) != 1:
        return None
    return int_to_string(MOD_EXP(m_int, e, N), k_bytes)

def E2(pk, m):
    """Encrypt ``m`` as ``E(pk, m || R || Z)``.

    R is ``ell_bytes // 2`` random bytes and Z is the ``ell_bytes // 2``-byte
    encoding of 0. ``m`` must be ``k_bytes - 1 - ell_bytes`` bytes long
    (asserted). Reads the module-level globals ``k_bytes`` and ``ell_bytes``.
    """
    assert len(m) == k_bytes - 1 - ell_bytes
    half = ell_bytes // 2
    salt = random_string(half)
    zeros = int_to_string(0, half)
    return E(pk, m + salt + zeros)


"""
[8 points] Write the decryption algorithm D2 such that (K, E2, D2) is a correct
asymmetric encryption scheme. D, above, may be a useful helper function, but you
only need to fill out D2 below.

NOTE: `ell_bytes` and `k_bytes` are both available quantities to use in your
algorithm. They are the byte lengths corresponding to "l" and "k" in the PDF
writeup (ell_bytes = ell // 8 and k_bytes = k // 8 in the test code below).
"""

def D2(sk, c):
    """Decrypt a ciphertext produced by E2.

    Recovers ``X = M || R || Z`` with D and accepts only when X has the
    layout E2 produces: ``k_bytes - 1`` bytes whose final ``ell_bytes // 2``
    bytes (Z) encode 0. Returns the leading ``k_bytes - 1 - ell_bytes``
    bytes (M), or None for any malformed ciphertext. Reads the module-level
    globals ``k_bytes`` and ``ell_bytes``.
    """
    if c is None or len(c) != k_bytes:
        return None
    X = D(sk, c)
    if X is None or len(X) != k_bytes - 1:
        return None
    z_len = ell_bytes // 2
    # Explicit start index: X[-z_len:] would be all of X when z_len == 0.
    if string_to_int(X[len(X) - z_len:]) != 0:
        return None
    m_len = k_bytes - 1 - ell_bytes
    m = X[:m_len]
    # Guards a negative m_len (ell_bytes too large for k_bytes), where the
    # slice would silently return a different length.
    if len(m) != m_len:
        return None
    return m



"""
[12 points] Show that (K, E2, D2) is not IND-CCA-secure by presenting an O(k^3)
adversary A which makes 1 LR (enc) query, 1 decryption (dec) query, and achieves
advantage at least 0.95.

In the below function A, the input `enc` refers to the LR oracle, `dec` is the
decryption oracle, and `pk = enc.public_key` retrieves the public key.

NOTE: D2 must be correctly implemented above in order for the local testing code
below to accurately evaluate your adversary A. You can always use the Gradescope
autograder, which will evaluate D2 and A separately.
"""

def A(enc, dec):
    """IND-CCA adversary against (K, E2, D2) using RSA's multiplicative malleability.

    Query ``enc(M0, M1)`` with M0 = 0 and M1 = 1. This yields C = X^e mod N,
    where X = M*2^L + R*2^z, L = 8*ell_bytes and z = 8*(ell_bytes//2). For
    even ell_bytes, L = 2z.

    C^2 mod N encrypts X^2 and differs from C, so the dec oracle answers it.
    X^2 = M^2*2^(2L) + 2MR*2^(L+z) + R^2*2^(2z) still ends in z zero bits.
    For these tiny M it is far below N (assuming k is comfortably larger
    than 2L, e.g. k=1024, ell=128), so D2 returns floor(X^2 / 2^L):
      * b = 0 (M = 0): R^2 < 2^L
      * b = 1 (M = 1): 2^L + 2R*2^z + R^2 >= 2^L
    Comparing against 2^L therefore separates the worlds exactly. The cost
    is one enc query, one dec query and one modular multiplication, well
    within O(k^3).

    Only ``ell_bytes`` and ``k_bytes`` are used (not the ``__main__``-only
    ``ell``), matching the quantities the assignment guarantees.
    """
    pk = enc.public_key
    N, _e = pk
    m_len = k_bytes - 1 - ell_bytes
    M0 = int_to_string(0, m_len)
    M1 = int_to_string(1, m_len)

    C = enc(M0, M1)
    if C is None:
        # E2 only fails when X is not a unit mod N (negligible); any guess works.
        return 1
    int_c = string_to_int(C)
    new_c = MOD(int_c * int_c, N)
    new_m = dec(int_to_string(new_c, k_bytes))
    if new_m is None:
        # Only possible if X^2 failed to decode (negligible); any guess works.
        return 1

    threshold = 2 ** (8 * ell_bytes)
    return 0 if string_to_int(new_m) < threshold else 1




"""
==============================================================================================
The following lines are used to test your code, and should not be modified.
==============================================================================================
"""


class GamePKCPA(Game):
    """
    This game is used to test whether or not asymmetric encryption
    schemes are secure under a chosen plaintext attack. Adversaries
    playing this game have access to an lr(l, r) oracle, exposed as
    ``self.lr`` (an ``enc_oracle`` that also carries the public key).

    This is the public key alternative to Playcrypt's GameLR. Ideally,
    it should go in Playcrypt.
    """

    def __init__(self, nqueries_encrypt, encrypt, key_gen, *args, **kwargs):
        """
        :param nqueries_encrypt: This should be either a nonnegative
                        integer giving the exact number of queries to
                        the encryption oracle that are expected, or a
                        range of allowed values, such as ``range(0,100)``
                        for up to 99 queries.
        :param encrypt: This should be a function that takes a public key and
                        plaintext as parameters and returns a ciphertext.
                        Its public key must have been returned by the
                        ``key_gen`` parameter.
        :param key_gen: This should be a function that generates a
                        (public, private) key pair. All extra arguments
                        that are passed to this constructor go to the
                        key_gen function.
        :raises ValueError: if ``nqueries_encrypt`` is neither an int nor a range.
        """
        super().__init__()
        # An exact count n is normalised to the one-element range [n, n].
        # ``type(...) is int`` (not isinstance) deliberately rejects bools.
        if type(nqueries_encrypt) is int:
            nqueries_encrypt = range(nqueries_encrypt, nqueries_encrypt + 1)
        elif type(nqueries_encrypt) is not range:
            raise ValueError("Incorrect type for nqueries_encrypt",
                    type(nqueries_encrypt))
        self.nqueries_encrypt = nqueries_encrypt
        self.encrypt = encrypt
        self.key_gen = key_gen
        self.key_args = args
        self.key_kwargs = kwargs
        self.b = -1
        self.c_list = []
        self.public_key = None
        self.private_key = None
        self.message_pairs = []
        self.warnings = defaultdict(int)
        # Replaced by an enc_oracle bound to the fresh key in initialize_().
        self.lr = None

    class enc_oracle(object):
        """Callable LR oracle handed to the adversary; also exposes ``public_key``."""

        def __init__(self, parent):
            self.public_key = parent.public_key
            self.__parent = parent

        def __call__(self, *args, **kwargs):
            return self.__parent.lr_(*args, **kwargs)

    def initialize_(self, b=None):
        """
        This method initializes the game, generates a new key, and selects a
        random world if needed.

        :param b: This is an optional parameter that allows the simulator
                  to control which world the game is in. This allows for
                  more exact simulation measurements.
        :return: ``(pk, sk)`` as produced by ``key_gen``.
        """
        pk, sk = self.key_gen(*self.key_args, **self.key_kwargs)
        self.public_key = pk
        self.lr = GamePKCPA.enc_oracle(self)
        # Keep the random bit when the caller did not fix the world; assigning
        # ``self.b = b`` unconditionally would leave b = None and make every
        # guess lose.
        self.b = random.randrange(0, 2, 1) if b is None else b
        # Per-round state: clear queries, returned ciphertexts and warnings so
        # nothing leaks from one simulated run into the next.
        self.message_pairs = []
        self.c_list = []
        self.warnings = defaultdict(int)
        return pk, sk

    def initialize(self, b=None):
        pk, sk = self.initialize_(b)
        # We only return the public key to the adversary. We don't even
        # store the secret key in this game.
        # GamePKCCA does differently, of course.
        return pk

    def lr_(self, l, r):
        """
        This is an lr oracle. It returns the encryption of either the left
        or right message, which must be of equal length. A query for a
        particular pair is only allowed to be made once.

        :param l: Left message.
        :param r: Right message.
        :return: Encryption of left message in left world and right message in
                 right world. If the messages are not of equal length, or the
                 pair was already queried, then ``None`` is returned.
        """
        if len(l) != len(r):
            self.warnings['encryption oracle called with messages of different length'] += 1
            return None
        if (l,r) in self.message_pairs:
            self.warnings['encryption oracle called several times with identical message pairs'] += 1
            return None
        self.message_pairs.append((l,r))
        if self.b == 1:
            c = self.encrypt(self.public_key, r)
        else:
            c = self.encrypt(self.public_key, l)
        self.c_list.append(c)
        return c

    def finalize(self, guess):
        """Report warnings, enforce the query budget, and return ``guess == b``.

        :raises ValueError: if the number of LR queries is outside the allowed range.
        """
        for msg, count in self.warnings.items():
            print("reported warning (%d times): %s" % (count, msg))
        constraint = self.nqueries_encrypt
        called = len(self.message_pairs)
        if called not in constraint:
            raise ValueError("The number of encryption queries of the adversary (%d) must be in %s" % (called, constraint))
        return guess == self.b


class GamePKCCA(GamePKCPA):
    """
    This game is used to test whether or not asymmetric encryption
    schemes are secure under a chosen ciphertext attack. Adversaries
    playing this game have access to an lr(l, r) and a dec oracle.

    This is the public key alternative to Playcrypt's GameCCA. Ideally,
    it should go in Playcrypt.
    """

    def __init__(self, nqueries_encrypt, encrypt, nqueries_decrypt, decrypt, key_gen, *args, **kwargs):
        """
        :param nqueries_encrypt: This should be either a nonnegative
                        integer giving the exact number of queries to
                        the encryption oracle that are expected, or a
                        range of allowed values, such as ``range(0,100)``
                        for up to 99 queries.
        :param encrypt: This should be a function that takes a public key and
                        plaintext as parameters and returns a ciphertext.
                        Its public key must have been returned by the
                        ``key_gen`` parameter.
        :param nqueries_decrypt: Same as ``nqueries_encrypt``, but for
                        the decryption oracle.
        :param decrypt: This should be a function that takes a private
                        key and ciphertext as parameters and returns the
                        plaintext message.
                        Its private key must have been returned by the
                        ``key_gen`` parameter.
        :param key_gen: This should be a function that generates a
                        (public, private) key pair. All extra arguments
                        that are passed to this constructor go to the
                        key_gen function.
        :raises ValueError: if either query count is neither an int nor a range.
        """
        # The parent stores key_gen/args/kwargs and initialises c_list,
        # private_key and warnings, so they are not re-assigned here.
        super().__init__(nqueries_encrypt, encrypt, key_gen, *args, **kwargs)
        if type(nqueries_decrypt) is int:
            nqueries_decrypt = range(nqueries_decrypt, nqueries_decrypt + 1)
        elif type(nqueries_decrypt) is not range:
            raise ValueError("Incorrect type for nqueries_decrypt",
                    type(nqueries_decrypt))
        self.nqueries_decrypt = nqueries_decrypt
        self.decrypt = decrypt
        self.ncalls_decrypt = 0

    def initialize(self, b=None):
        """Start a new round; unlike the CPA game, keep the secret key for ``dec``."""
        pk, sk = super().initialize_(b)
        self.private_key = sk
        self.c_list = []
        self.ncalls_decrypt = 0
        return self.public_key

    def dec(self, c):
        """
        This is a decryption oracle. The adversary can query to decrypt any
        ciphertext that it did not receive from the lr oracle.

        :param c: Ciphertext to decrypt.
        :return: Decryption of ciphertext if valid. None otherwise.
        """
        self.ncalls_decrypt += 1
        if c in self.c_list:
            self.warnings['decryption oracle called on previously returned ciphertext'] += 1
            return None
        return self.decrypt(self.private_key, c)

    def finalize(self, guess):
        """Enforce both query budgets and return whether the guess was correct.

        :raises ValueError: if either query count is outside its allowed range.
        """
        # The parent already prints every warning (including decryption ones,
        # which share self.warnings) and checks the encryption count, so only
        # the decryption count is checked here.
        result = super().finalize(guess)
        constraint = self.nqueries_decrypt
        called = self.ncalls_decrypt
        if called not in constraint:
            raise ValueError("The number of decryption queries of the adversary (%d) must be in %s" % (called, constraint))
        return result


if __name__ == '__main__':
    # Parameters are module globals on purpose: E, D, E2, D2 and A read
    # k_bytes / ell_bytes (and ell) from module scope.
    k = 1024
    k_bytes = k//8
    pk, sk = K(k)
    N = pk[0]

    # Round-trip check of the raw RSA helpers.
    m = random_string(k_bytes-1)
    c = E(pk, m)
    mm = D(sk, c)
    assert mm == m

    ell = 128
    ell_bytes = ell // 8

    # Round-trip check of E2/D2; the else branch runs only if no trial broke out.
    for _trial in range(10):
        m = random_string(k_bytes-1-ell_bytes)
        c = E2(pk, m)
        mm = D2(sk, c)
        if mm != m:
            print("Your decryption function is incorrect.")
            break
    else:
        print("Your decryption function appears correct.")

    # Estimate the IND-CCA advantage of A (simulator runs both worlds).
    G = GamePKCCA(1, E2, 1, D2, K, k)
    sim = CCASim(G, A)
    nrounds = 10
    started = timer()
    adv = sim.compute_advantage(nrounds)
    elapsed = timer() - started
    print(f"[k={k} ell={ell}] Your adversary A ran {2*nrounds} times in {elapsed:#.1f} seconds")
    print(f"[k={k} ell={ell}] The advantage of your adversary A is approximately {adv}")