# CSE-107 / hw5 / hw5.py
from ast import Mult
import math
import json
import sys, os, itertools
from timeit import default_timer as timer

from playcrypt.primitives import *
from playcrypt.tools import *
from playcrypt.new_tools import *
from playcrypt.games.game_ufcma import GameUFCMA
from playcrypt.simulator.ufcma_sim import UFCMASim

def ADD(a,b):
    """Return the sum a + b (no modular reduction)."""
    total = a + b
    return total
def MULT(a,b):
    """Return the product a * b (no modular reduction)."""
    product = a * b
    return product
def INT_DIV(a,N):
    """Return the pair (quotient, remainder) of a divided by N.

    Uses Python floor division, so a == quotient * N + remainder and the
    remainder takes the sign of N.
    """
    quotient = a // N
    remainder = a % N
    return quotient, remainder
def MOD(a,N):
    """Return a reduced modulo N (Python's %, so the result takes N's sign)."""
    residue = a % N
    return residue
def EXT_GCD(a,N):
    """Run the extended Euclidean algorithm on (a, N) via playcrypt's ``egcd``.

    :return: exactly what ``egcd(a, N)`` returns (presumably the gcd together
        with Bezout coefficients — confirm against playcrypt).
    """
    outcome = egcd(a, N)
    return outcome
def MOD_INV(a,N):
    """Return the multiplicative inverse of a modulo N.

    Delegates to playcrypt's ``modinv``, which signals a non-invertible ``a``
    by returning ``None``; that case is turned into an exception so callers
    never receive ``None`` as if it were a residue.

    :raises ValueError: if a has no inverse modulo N.
    """
    res = modinv(a,N)
    # Identity check against the None singleton, not equality.
    if res is None:
        raise ValueError("Inverse does not exist.")
    return res
def MOD_EXP(a,n,N):
    """Modular exponentiation a^n mod N, delegated to playcrypt's ``exp``."""
    power = exp(a, n, N)
    return power

"""
Note: As usual, our convention is that the running time of an adversary does not
include the time taken by game procedures to compute responses to the adversary's queries.
"""

"""
1. Problem 1 [8 points] Let p>=3 be a prime and g \in Z_p^* be a generator of Z_p^*.
We define below the key generation algorithm K1 and the encryption algorithm E1.
The message M must be in Z_p^*, meaning the message space is Z_p^*. Let k be the bit-length
of p. (We will define p and g for you.)
"""

def K1():
    """Generate a key pair for AE1.

    Reads the module-level prime ``p`` and generator ``g``.

    :return: (pk, sk), where sk = x is drawn with random_Z_N(p - 1) and
        pk = X = g^x mod p.
    """
    secret_exponent = random_Z_N(p - 1)
    public_value = MOD_EXP(g, secret_exponent, p)
    return (public_value, secret_exponent)

def E1(pk,M):
    """Encrypt M under the public key pk = X = g^x mod p (ElGamal over Z_p^*).

    Reads the module-level ``p`` and ``g``.

    :param pk: the public key X returned by K1
    :param M: the plaintext; must be in Z_p^*
    :return: ciphertext (Y, W) with Y = g^y mod p and W = X^y * M mod p,
        for a fresh y drawn with random_Z_N(p - 1)
    :raises ValueError: if M is not in Z_p^*
    """
    if not in_Z_N_star(M, p):
        raise ValueError("Message not in appropriate domain")
    y = random_Z_N(p - 1)
    Y = MOD_EXP(g, y, p)
    # Mask the message with the shared value X^y = g^(x*y) mod p.
    masked = MOD(MOD_EXP(pk, y, p) * M, p)
    return (Y, masked)

"""
Let j be the bit-length of p. Specify an O(j^3)-time decryption algorithm D1 such that
AE1 = (K1, E1, D1) is an asymmetric encryption scheme satisfying the correct decryption
property.
"""

def D1(sk, C):
    """Decrypt a ciphertext produced by E1 (ElGamal decryption over Z_p^*).

    With sk = x, a ciphertext (Y, W) from E1 satisfies Y = g^y and
    W = X^y * M = g^(x*y) * M (mod p). Raising Y to x recomputes the mask
    Z = g^(x*y) mod p, and multiplying W by Z^{-1} mod p recovers M. No
    discrete logarithm is needed. The cost is one modular exponentiation
    plus one modular inversion. That meets the required O(j^3) bound for
    a j-bit p, assuming MOD_EXP is square-and-multiply.

    Reads the module-level prime ``p``.

    :param sk: the secret exponent x returned by K1
    :param C: the ciphertext tuple (Y, W) returned by E1
    :return: the plaintext M in Z_p^*
    :raises ValueError: (from MOD_INV) if Z is not invertible mod p
    """
    Y, W = C
    Z = MOD_EXP(Y, sk, p)
    return MOD(MOD_INV(Z, p) * W, p)


"""
2. Problem 2 [6 points] Let p be a k-bit prime such that q = (p - 1)/2 is also prime, and assume
k >= 2048. Let g \in Z_p^* be a generator of Z_p^* and h = g^2. (These are public quantities, known
to all parties including the adversary.) We will define variables p, q and h for you.
Consider the family of functions T : (Z_q^*)^2 x Z_q^* -> Z_p^* defined below:
"""

def T(K, M):
    """Evaluate the function family T on key K and message M.

    Writing a = K[0], b = K[1] and taking inverses modulo q,
        T(K, M) = h^((a^{-1} + (M*b)^{-1}) mod q) mod p,
    where q = (p - 1) // 2 and h = g^2 mod p are derived from the
    module-level ``p`` and ``g``.

    :param K: the (secret) key, a pair (K[0], K[1]) with both entries in Z_q^*
    :param M: the message; must be in Z_q^*
    :return: the output W = h^u mod p
    :raises ValueError: if M, K[0] or K[1] is not in Z_q^*
    """
    q = (p-1) // 2          # do not modify this
    h = MOD(MULT(g,g), p)   # do not modify this
    if not in_Z_N_star(M,q):
        raise ValueError("Message not in appropriate domain.")
    if not in_Z_N_star(K[0],q):
        raise ValueError("Key 0 not in appropriate domain.")
    if not in_Z_N_star(K[1],q):
        raise ValueError("Key 1 not in appropriate domain.")
    a = K[0]
    b = MOD(MULT(M, K[1]), q)
    x = MOD_INV(a, q)
    y = MOD_INV(b, q)
    u = MOD(ADD(x, y), q)
    W = MOD_EXP(h, u, p)

    return W

"""
The message M must be in Z_q^*, meaning only elements of Z_q^* are allowed as
messages. We let k be the bit-length of p.

Specify an O(k^3)-time adversary A making three Tag queries such that Adv_{T}^{uf-cma}(A) = 1.
The messages in the Tag queries, and the one returned by A, must be in Z_q^*.
"""
# p, q, h are known
def A(tag):
    q = (p-1) // 2          # do not modify this
    h = MOD(MULT(g,g), p)   # do not modify this
    """
    You must fill in this method. This is the adversary that the problem is
    asking for. Returns a (message, tag) pair.
    :param tag: This is an oracle supplied by GameUFCMA, you can call this
    oracle to get a "tag" for the data you pass into it.
    """
    # Multiplication arithmetic: (a mod p)(b mod p) mod p = (a * b) mod p
    # print(MOD(MULT(W1, MOD_INV(W1, p)),p)) ### 1
    # print(MOD(MULT(h, MULT(W1, MOD_INV(W1, p))), p) == h) ### true 
    M1 = 1
    M2 = 2
    W1 = tag(M1)
    W2 = tag(M2) 

    # GET: h^(-1/k[0] - 1/(M2*k[1]))
    W2_mod_inv = MOD_INV(W2, p)
    H_K_1_half = MOD(MULT(W1, W2_mod_inv), p)     # TRUE Found h^1/(2*k[1])
    H_K_1 = MOD(MULT(H_K_1_half, H_K_1_half), p)  # TRUE Found h^1/k[1]

    H_K_1_mod_inv = MOD_INV(H_K_1, p)             
    H_K_0 = MOD(MULT(W1, H_K_1_mod_inv),p)        # TRUE found h^1/k[0]

    """
    TESTING IMPLEMENTATION: 
    """
# Check W1 == (h^1/k[0]*h^1/k[1])
    # combined = MOD(MULT(H_K_0, H_K_1),p)        # TRUE
    # print(W1 == combined )                      # TRUE: replicate W
# Check W2 
    # M2_inv = MOD_INV(M2,q)                      
    # check_exp= MOD_EXP(H_K_1, M2_inv, p)        # Apply 1/M to h^1/K[1]  
    # combined2 = MOD(MULT(H_K_0, check_exp),p)     
    # combined2 = MOD(MULT(H_K_0, H_K_1_half),p)
    # print(W2 == combined2)                      # TRUE: replicate W2
    """
    TESTING IMPLEMENTATION
    """

    M3 = 3
    M3_inv = MOD_INV(M3, q)
    exp_hk_1_m3 = MOD_EXP(H_K_1, M3_inv, p)
    tag3 = MOD(MULT(H_K_0, exp_hk_1_m3),p)

    return (M3, tag3)


"""
==============================================================================================
The following lines are used to test your code, and should not be modified.
==============================================================================================
"""
def V(K, M, t):
    """Verify a tag: return 1 if t equals T(K, M), otherwise 0.

    Any ValueError raised by T (M or a key component outside Z_q^*)
    propagates to the caller.
    """
    if T(K, M) == t:
        return 1
    else:
        return 0

def kgen():
    """Sample a key for T: a list of two values from random_Z_N_star(q).

    q = (p - 1) // 2 is recomputed from the module-level ``p``.
    """
    q = (p-1) // 2
    return [random_Z_N_star(q) for _ in range(2)]

def main():
    """Run the local self-tests for Problems 1 and 2.

    Problem 1:
        Uses a 16-bit prime p and a generator g of Z_p^*. Encrypts 100
        random messages with K1/E1, checks that D1 recovers each one, and
        reports the elapsed time.

    Problem 2:
        Rebinds the global p and g to the fixed large parameters below and
        wraps (T, V, kgen) in a GameUFCMA. Prints the advantage of A as
        estimated by UFCMASim.compute_advantage(10), plus the time taken.
    """
    print("When j=16:")
    j = 16
    # The functions above read p and g as module globals, so main rebinds them.
    global p,g
    p = prime_between(2**(j-1),2**j)
    g = find_generator_Z_N_star(p)

    # Problem 1: correctness check of D1 over 100 fresh key pairs and messages.
    start_sol = timer()
    worked = True
    for loop in range(100):
        (pk,sk) = K1()
        M = random_Z_N_star(p)
        C = E1(pk, M)
        if M != D1(sk, C):
            print ("Your first decryption function is incorrect.")
            worked = False
            break
    end_sol = timer()
    if worked:
        print ("Your first decryption function appears correct. Your decryption function takes "+str(end_sol - start_sol)+" seconds to decrypt 100 messages.")

    # NOTE(review): the reason for raising the recursion limit is not visible
    # here; presumably some playcrypt routine recurses on large inputs — confirm.
    sys.setrecursionlimit(2000)

    # Problem 2 parameters: s holds the decimal digits of p and s1 those of q
    # (presumably q = (p - 1) // 2), each split across nine string chunks.
    s="%s%s%s%s%s%s%s%s%s" % (
        "105769031300962262347041907699188531986468080875168500291924909352726",
        "523514845129707954505966025516424111093772977370588541844230910553861",
        "654594521893564682962587494595327088018194062158369048764578647394423",
        "878171592936776203779788139029433508667038014539146349807292798552653",
        "558648914249755271208710270667471762323995019920493552915437914409534",
        "037901781078867071136428476829769041285776781721175931798108515586171",
        "194767580162162682727491216460561250733810235821601248338156562332774",
        "779512075289235506520068278611847500719701594730286561581764577491123",
        "045250981621775358501843358439025539644963902870293437315451576463")

    s1="%s%s%s%s%s%s%s%s%s" % (
        "528845156504811311735209538495942659932340404375842501459624546763632",
        "617574225648539772529830127582120555468864886852942709221154552769308",
        "272972609467823414812937472976635440090970310791845243822893236972119",
        "390857964683881018898940695147167543335190072695731749036463992763267",
        "793244571248776356043551353337358811619975099602467764577189572047670",
        "189508905394335355682142384148845206428883908605879658990542577930855",
        "973837900810813413637456082302806253669051179108006241690782811663873",
        "897560376446177532600341393059237503598507973651432807908822887455615",
        "22625490810887679250921679219512769822481951435146718657725788231")

    p = int(s)
    # NOTE: q and h are locals of main and are not read afterwards here.
    # T, A and kgen recompute q (and T/A recompute h) from the global p and g.
    q = int(s1)
    g = 5
    h = MOD(MULT(g,g), p)

    gm = GameUFCMA(2, T, V, None, kgen)
    # s is rebound here from the digit string to the simulator object.
    s = UFCMASim(gm, A)

    start_sol = timer()
    print ("The advantage of your adversary is ~" + str(s.compute_advantage(10)))
    end_sol = timer()
    deltatime = end_sol - start_sol
    print("Your adversary takes "+str(deltatime)+" seconds for 10 tests.")
    # Warn when the 10 runs together took more than 6 seconds.
    if deltatime > 6:
        print("Your adversary will timeout in the autograder.")


# Run the self-tests only when executed as a script, not on import.
if __name__ == "__main__":
    main()