diff --git a/rust/server/Cargo.lock b/rust/server/Cargo.lock index 814e658..c1e9ff8 100644 --- a/rust/server/Cargo.lock +++ b/rust/server/Cargo.lock @@ -115,6 +115,8 @@ dependencies = [ "crc32fast", "lazy_static", "num", + "num-bigint", + "num-traits", "rand", "regex", "serde", diff --git a/rust/server/Cargo.toml b/rust/server/Cargo.toml index 9a35aeb..2926ecd 100644 --- a/rust/server/Cargo.toml +++ b/rust/server/Cargo.toml @@ -13,6 +13,8 @@ sha2 = { default-features = false, version = "0.10.6" } regex = { default-features = true, version = "1.8.3" } crc32fast = { default-features = false, version = "1.3.2" } rand = { default-features = true, features = ["std_rng"], version = "0.8.5" } +num-bigint = { default-features = true, version = "0.4.3" } +num-traits = { default-features = true, version = "0.2.15" } num = { default-features = true, version = "0.4.0" } lazy_static = { default-features = false, version = "1.4.0" } diff --git a/rust/server/src/main.rs b/rust/server/src/main.rs index 1e9fe7b..94a4767 100644 --- a/rust/server/src/main.rs +++ b/rust/server/src/main.rs @@ -1,11 +1,18 @@ use bincode::{self, options, Options}; -use std::{net::{UdpSocket, SocketAddr}, io::{Read, Seek}}; -use serde::{Serialize, Deserialize}; -use sha2::{Digest,Sha512}; -use regex::Regex; use crc32fast; +use num_bigint::ToBigUint; +use num_traits::ToPrimitive; use rand::{rngs::StdRng, RngCore, SeedableRng}; -use num::{BigUint, one}; +use regex::Regex; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha512}; +use std::{ + collections::HashMap, + io::{Read, Seek}, + net::{SocketAddr, UdpSocket}, +}; + +use num::{BigUint, NumCast, One, Zero}; mod big_array; mod primality_test; @@ -14,22 +21,22 @@ use big_array::BigArray; use crate::prime_utils::prime_utils::is_prime_default; -const MAX_FRAME_PAYLOAD:u16=508; -const MAX_FRAME_PAYLOAD_U:usize=MAX_FRAME_PAYLOAD as usize; -const HEADER_SIZE:u16 = 12; -const MAX_PAYLOAD:u16 = MAX_FRAME_PAYLOAD - HEADER_SIZE; -const MAX_PAYLOAD_U:usize = MAX_PAYLOAD as usize; +const MAX_FRAME_PAYLOAD: u16 = 508; +const MAX_FRAME_PAYLOAD_U: usize = MAX_FRAME_PAYLOAD as usize; +const HEADER_SIZE: u16 = 12; +const MAX_PAYLOAD: u16 = MAX_FRAME_PAYLOAD - HEADER_SIZE; +const MAX_PAYLOAD_U: usize = MAX_PAYLOAD as usize; #[derive(Serialize, Deserialize)] struct PacketInformation { - packet_numbers: u32, //4 bytes - last_packet_size: u16, //2 bytes + packet_numbers: u32, //4 bytes + last_packet_size: u16, //2 bytes response_filename_checksum: u64, //8 bytes } //14 bytes #[derive(Serialize, Deserialize)] struct Packet { - packet_number: u32, //4 bytes + packet_number: u32, //4 bytes payload_crc32: [u8; 4], //4 bytes payload_sum_crc32: u32, //4 bytes #[serde(with = "BigArray")] @@ -39,7 +46,7 @@ struct Packet { #[derive(Serialize, Deserialize)] struct StrPacket { #[serde(with = "BigArray")] - payload: [u8; MAX_PAYLOAD_U] + payload: [u8; MAX_PAYLOAD_U], } fn pow_mod(num: BigUint, pow: BigUint, modulo: BigUint) -> BigUint { @@ -47,68 +54,125 @@ fn pow_mod(num: BigUint, pow: BigUint, modulo: BigUint) -> BigUint { let mut i = BigUint::from(0u8); while i < pow { result = (result * &num) % &modulo; - i += one::(); + i += BigUint::one(); } result } +fn nth_prime_approx(n: f64) -> f64 { + let ln_n = n.ln(); + let ln_ln_n = ln_n.ln(); + let ln_ln_ln_n = ln_ln_n.ln(); + + n * (ln_n + ln_ln_n - 1.0 + (ln_ln_ln_n - 2.0) / ln_n + - ((ln_ln_n).powi(2) - 6.0 * ln_ln_n + 11.0) / (2.0 * 2.0f64.log2() * ln_n) + + ((ln_ln_n / ln_n).powi(3)) * (1.0 / ln_n)) +} + +fn ln(x: &BigUint) -> BigUint { + let mut sum = x.clone(); + let mut term = x.clone(); + let two = BigUint::from(2u32); + let ten = BigUint::from(10u32); + let precision = 100; + + for i in 1..precision { + term /= &two; + if i % 2 == 0 { + sum += &term; + } else { + sum -= &term; + } + } + + sum *= &ten.pow(precision); + sum /= x.clone(); + + sum +} + + +//TODO: more precise approximation (Riemann R function?) +fn prime_approx(num: BigUint) -> BigUint { + //x/log(x) + let ln_num = ln(&num); + num / ln_num +} //TODO: make this faster fn new_p() -> BigUint { - let mut private_key = [0u8; 128]; + const BITS: usize = 64; + let mut private_key = [0u8; BITS]; let mut rng = StdRng::from_entropy(); rng.fill_bytes(&mut private_key); - let mut num = BigUint::from_bytes_be(&private_key); + let rand_num = BigUint::from_bytes_be(&private_key); + let mut num = nth_prime_approx(prime_approx(rand_num.clone()).to_f64().unwrap()) + .to_biguint() + .unwrap(); + println!( + "guesstimating {}th prime: {}", + prime_approx(rand_num.clone()).to_f64().unwrap(), + num + ); if is_prime_default(&num) { return num; } else { - let higher: BigUint; loop { - num += one::(); + num += BigUint::one(); if is_prime_default(&num) { - higher = num; - break; + return num; } } - num = BigUint::from_bytes_be(&private_key); - let lower: BigUint; - loop { - num -= one::(); - if is_prime_default(&num) { - lower = num.clone(); - break; - } - } - if &higher - &num > &num - &lower { - return lower; - } else { - return higher; - } } } +//TODO: test this properly +fn new_q(p: &BigUint) -> BigUint { + //find a q such that q is a primitive root of p + let mut q = BigUint::from(2u8); + loop { + if pow_mod(q.clone(), p.clone() - BigUint::one(), p.clone()) == BigUint::zero() { + q += BigUint::one(); + } else { + break; + } + } + q +} + fn main() { let port = "1337"; let timeout = 100; //ms - let local_addr: SocketAddr = ("0.0.0.0:".to_string()+port).parse().expect("Failed to parse address"); + let local_addr: SocketAddr = ("0.0.0.0:".to_string() + port) + .parse() + .expect("Failed to parse address"); let socket = UdpSocket::bind(local_addr).expect("Failed to bind socket"); - socket.set_read_timeout(Some(std::time::Duration::from_millis(timeout))).expect("set_read_timeout call failed"); + socket + .set_read_timeout(Some(std::time::Duration::from_millis(timeout))) + .expect("set_read_timeout call failed"); println!("UDP Server up and running on port {}", local_addr.port()); - let options = options().with_big_endian().allow_trailing_bytes().with_fixint_encoding(); + let options = options() + .with_big_endian() + .allow_trailing_bytes() + .with_fixint_encoding(); let hash_request_regex = Regex::new(r"[\w.-_ ]+:").unwrap(); let missing_packet_request_regex = Regex::new(r"([\w.-_ ]+)(/\d{1,10})+").unwrap(); + let p_q_lookup: HashMap = HashMap::new(); + // [ID: u128] = (p: [u8,64], q: [u8,64?]) + // TODO: find fixed size for q + loop { let mut buffer = [0u8; MAX_FRAME_PAYLOAD_U]; let res = socket.recv_from(&mut buffer); if let Ok((_received_bytes, remote_addr)) = res { - let filled_buffer = &buffer;//[..received_bytes]; + let filled_buffer = &buffer; //[..received_bytes]; let request_packet = bincode::deserialize::(filled_buffer).unwrap(); let mut request: String = request_packet.payload.iter().map(|&c| c as char).collect(); @@ -147,7 +211,9 @@ fn main() { let mut result_bytes = [0u8; 65]; result_bytes[..64].copy_from_slice(&result[..64]); - socket.send_to(&result_bytes, remote_addr).expect("Failed to send file hash"); + socket + .send_to(&result_bytes, remote_addr) + .expect("Failed to send file hash"); continue; } @@ -165,7 +231,10 @@ fn main() { let mut file = possible_file.unwrap(); for packet_number in split { - file.seek(std::io::SeekFrom::Start((packet_number.parse::().unwrap() * MAX_PAYLOAD as u64) as u64)).unwrap(); + file.seek(std::io::SeekFrom::Start( + (packet_number.parse::().unwrap() * MAX_PAYLOAD as u64) as u64, + )) + .unwrap(); let mut file_buffer = [0u8; MAX_PAYLOAD_U]; let bytes_read = file.read(&mut file_buffer).unwrap(); let mut packet = Packet { @@ -176,7 +245,9 @@ fn main() { }; packet.payload[..bytes_read].copy_from_slice(&file_buffer[..bytes_read]); let packet_bytes = options.serialize(&packet).unwrap(); - socket.send_to(&packet_bytes, remote_addr).expect("Failed to send packet"); + socket + .send_to(&packet_bytes, remote_addr) + .expect("Failed to send packet"); } continue; @@ -184,6 +255,11 @@ fn main() { println!("Received file request"); + let p = new_p(); + let q = new_q(&p); + + println!("p: {} q: {}", p, q); + let file_result = std::fs::File::open("./files/".to_string() + req); if file_result.is_err() { @@ -200,15 +276,23 @@ fn main() { let file_length = file.metadata().unwrap().len(); println!("file length: {}", file_length); - let packet_numbers = f64::ceil(file_length as f64/MAX_PAYLOAD as f64) as u32; + let packet_numbers = f64::ceil(file_length as f64 / MAX_PAYLOAD as f64) as u32; let packet_information = PacketInformation { packet_numbers, - last_packet_size: (file_length%MAX_PAYLOAD as u64) as u16, - response_filename_checksum: request_packet.payload.iter().fold(0u64, |acc, &x| acc.wrapping_add(x as u64)), + last_packet_size: (file_length % MAX_PAYLOAD as u64) as u16, + response_filename_checksum: request_packet + .payload + .iter() + .fold(0u64, |acc, &x| acc.wrapping_add(x as u64)), }; - socket.send_to(options.serialize(&packet_information).unwrap().as_slice(), remote_addr).expect("Failed to send packet information"); + socket + .send_to( + options.serialize(&packet_information).unwrap().as_slice(), + remote_addr, + ) + .expect("Failed to send packet information"); let mut sha512_hasher = Sha512::new(); @@ -231,7 +315,9 @@ fn main() { }; packet.payload[..bytes_read].copy_from_slice(&file_buffer[..bytes_read]); let packet_bytes = options.serialize(&packet).unwrap(); - socket.send_to(&packet_bytes, remote_addr).expect("Failed to send packet"); + socket + .send_to(&packet_bytes, remote_addr) + .expect("Failed to send packet"); packet_number += 1; sha512_hasher.update(&file_buffer[..bytes_read]); @@ -242,7 +328,9 @@ fn main() { let mut result_bytes = [0u8; 65]; result_bytes[..64].copy_from_slice(&result[..64]); - socket.send_to(&result_bytes, remote_addr).expect("Failed to send file hash"); + socket + .send_to(&result_bytes, remote_addr) + .expect("Failed to send file hash"); } } } diff --git a/rust/server/src/prime_utils.rs b/rust/server/src/prime_utils.rs index 68d49be..9cbdd23 100644 --- a/rust/server/src/prime_utils.rs +++ b/rust/server/src/prime_utils.rs @@ -1,19 +1,18 @@ - - pub mod prime_utils { - use num::{BigUint, One, Zero}; use lazy_static::lazy_static; + use num::{BigUint, One, Zero}; use crate::primality_test::primality_tests::is_probably_prime; - #[must_use] pub fn log_2(x: &BigUint) -> u64 { + #[must_use] + pub fn log_2(x: &BigUint) -> u64 { x.bits() - 1 } - #[must_use] pub fn is_prime_default(number: &BigUint) -> bool { - + #[must_use] + pub fn is_prime_default(number: &BigUint) -> bool { lazy_static! { - static ref defaultvec: Vec = { + static ref DEFAULTVEC: Vec = { let mut vec = Vec::new(); vec.push(BigUint::from(2u8)); vec.push(BigUint::from(3u8)); @@ -31,73 +30,106 @@ pub mod prime_utils { vec.push(BigUint::from(43u8)); vec.push(BigUint::from(47u8)); vec.push(BigUint::from(53u8)); + vec.push(BigUint::from(59u8)); + vec.push(BigUint::from(61u8)); + vec.push(BigUint::from(67u8)); + vec.push(BigUint::from(71u8)); + vec.push(BigUint::from(73u8)); + vec.push(BigUint::from(79u8)); + vec.push(BigUint::from(83u8)); + vec.push(BigUint::from(89u8)); + vec.push(BigUint::from(97u8)); + vec.push(BigUint::from(101u8)); + vec.push(BigUint::from(103u8)); + vec.push(BigUint::from(107u8)); + vec.push(BigUint::from(109u8)); + vec.push(BigUint::from(113u8)); + vec.push(BigUint::from(127u8)); + vec.push(BigUint::from(131u8)); + vec.push(BigUint::from(137u8)); + vec.push(BigUint::from(139u8)); + vec.push(BigUint::from(149u8)); + vec.push(BigUint::from(151u8)); + vec.push(BigUint::from(157u8)); + vec.push(BigUint::from(163u8)); + vec.push(BigUint::from(167u8)); + vec.push(BigUint::from(173u8)); + vec.push(BigUint::from(179u8)); + vec.push(BigUint::from(181u8)); + vec.push(BigUint::from(191u8)); + vec.push(BigUint::from(193u8)); + vec.push(BigUint::from(197u8)); + vec.push(BigUint::from(199u8)); + vec.push(BigUint::from(211u8)); + vec.push(BigUint::from(223u8)); + vec.push(BigUint::from(227u8)); + vec.push(BigUint::from(229u8)); + vec.push(BigUint::from(233u8)); + vec.push(BigUint::from(239u8)); + vec.push(BigUint::from(241u8)); + vec.push(BigUint::from(251u8)); //all primes that fit in a u8 vec }; } - return is_prime(number, &defaultvec); + return is_prime(number, &DEFAULTVEC); } - #[must_use] pub fn is_prime(number: &BigUint, g_primes: &Vec) -> bool { + #[must_use] + pub fn is_prime(number: &BigUint, g_primes: &Vec) -> bool { if BigUint::from(1u8) == *number { return false; } - if BigUint::from(4u8) > *number { + if BigUint::from(4u8) > *number { return true; } if number.sqrt().pow(2) == *number { return false; } - + let two = BigUint::from(2u8); - + // number = 2^a - 1 // a = log2(number + 1) - let a = log_2(&(number+1u8)); - if BigUint::from(2u8).pow(a as u32)-BigUint::one() != *number { - let mut i = BigUint::one(); + let a = log_2(&(number + 1u8)); + if BigUint::from(2u8).pow(a as u32) - BigUint::one() != *number { + let mut i; let one = BigUint::one(); let zero = BigUint::zero(); - - let sqrtnum = number.sqrt()+&one; //fake ceil function - if let Some(max_value) = g_primes.iter().max() { - if max_value > &sqrtnum { - for prime in g_primes { - if prime<&sqrtnum && number%prime == zero { - return false; - } - } + let sqrtnum = number.sqrt() + &one; //fake ceil function + + for prime in g_primes { + if prime < &sqrtnum && number % prime == zero { + return false; } } + i = g_primes.iter().max().unwrap().clone(); - if !is_probably_prime(number,5) { + if !is_probably_prime(number, 5) { return false; } loop { i += &one; - if number%&i == zero { + if number % &i == zero { return false; } if i == sqrtnum { return true; } } - - } - + // 4 12 194 let mut last = BigUint::from(4u8); - + for _i in 2..a { - last = (last.pow(2)-&two)%number; + last = (last.pow(2) - &two) % number; } last == BigUint::from(0u8) } } -