diff --git a/Cargo.lock b/Cargo.lock index af3045b..65a5184 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,12 +11,66 @@ dependencies = [ "serde", ] +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "byteorder" version = "1.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "cpufeatures" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e4c1eaa2012c47becbbad2ab175484c2a84d1185b566fb2cc5b8707343dfe58" +dependencies = [ + "libc", +] + +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "itransfer" version = "0.1.0" @@ -24,8 +78,25 @@ dependencies = [ "bincode", "byteorder", "serde", + "sha2", + "sha3", ] +[[package]] +name = "keccak" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f6d5ed8676d904364de097082f4e7d240b571b67989ced0240f08b7f966f940" +dependencies = [ + "cpufeatures", +] + +[[package]] +name = "libc" +version = "0.2.144" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b00cc1c228a6782d0f076e7b232802e0c5689d41bb5df366f2a6b6621cfdfe1" + [[package]] name = "proc-macro2" version = "1.0.58" @@ -64,6 +135,27 @@ dependencies = [ "syn", ] +[[package]] +name = "sha2" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82e6b795fe2e3b1e845bafcb27aa35405c4d47cdfc92af5fc8d3002f76cebdc0" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + +[[package]] +name = "sha3" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75872d278a8f37ef87fa0ddbda7802605cb18344497949862c0d4dcb291eba60" +dependencies = [ + "digest", + "keccak", +] + [[package]] name = "syn" version = "2.0.16" @@ -75,8 +167,20 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "typenum" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" + [[package]] name = "unicode-ident" version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5464a87b239f13a63a501f2701565754bae92d243d4bb7eb12f6d57d2269bf4" + +[[package]] +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" diff --git a/Cargo.toml b/Cargo.toml index 7ed66c9..b285aee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,4 +9,5 @@ edition = "2021" bincode = "1.3.3" serde = { version = "1.0.163", features = ["derive"] } byteorder = "1.4.3" - +sha3 = "0.10.8" +sha2 = "0.10.6" diff --git a/generate_data.py b/generate_data.py index 166acca..c9f02f4 100644 --- a/generate_data.py +++ b/generate_data.py @@ -1,6 +1,6 @@ import secrets -rndsize = 488*2 +rndsize = 488*100+5 rndbytes = secrets.token_bytes(rndsize) diff --git a/server.py b/server.py index 36b6234..3146ed4 100644 --- a/server.py +++ b/server.py @@ -1,6 +1,6 @@ import socket, struct, os, select, math, hashlib, re#, brotli -max_frame_payload=1016 #508 is the minimum size for routers to "understand" udp packets, but we can possibly send more data +max_frame_payload=508 #1016 #508 is the minimum size for routers to "understand" udp packets, but we can possibly send more data header = "!L16s" header_size = struct.calcsize(header) max_payload = max_frame_payload - header_size diff --git a/src/big_array.rs b/src/big_array.rs new file mode 100644 index 0000000..f586968 --- /dev/null +++ b/src/big_array.rs @@ -0,0 +1,68 @@ +use std::fmt; +use std::marker::PhantomData; +use serde::ser::{Serialize, Serializer, SerializeTuple}; +use serde::de::{Deserialize, Deserializer, Visitor, SeqAccess, Error}; + +pub trait BigArray<'de>: Sized { + fn serialize(&self, serializer: S) -> Result + where S: Serializer; + fn deserialize(deserializer: D) -> Result + where D: Deserializer<'de>; +} + +macro_rules! big_array { + ($($len:expr,)+) => { + $( + impl<'de, T> BigArray<'de> for [T; $len] + where T: Default + Copy + Serialize + Deserialize<'de> + { + fn serialize(&self, serializer: S) -> Result + where S: Serializer + { + let mut seq = serializer.serialize_tuple(self.len())?; + for elem in &self[..] { + seq.serialize_element(elem)?; + } + seq.end() + } + + fn deserialize(deserializer: D) -> Result<[T; $len], D::Error> + where D: Deserializer<'de> + { + struct ArrayVisitor { + element: PhantomData, + } + + impl<'de, T> Visitor<'de> for ArrayVisitor + where T: Default + Copy + Deserialize<'de> + { + type Value = [T; $len]; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str(concat!("an array of length ", $len)) + } + + fn visit_seq(self, mut seq: A) -> Result<[T; $len], A::Error> + where A: SeqAccess<'de> + { + let mut arr = [T::default(); $len]; + for i in 0..$len { + arr[i] = seq.next_element()? + .ok_or_else(|| Error::invalid_length(i, &self))?; + } + Ok(arr) + } + } + + let visitor = ArrayVisitor { element: PhantomData }; + deserializer.deserialize_tuple($len, visitor) + } + } + )+ + } +} + +big_array! { + 40, 48, 50, 56, 64, 72, 96, 100, 128, 160, 192, 200, 224, 256, 384, 488, 512, + 768, 1024, 2048, 4096, 8192, 16384, 32768, 65536, +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index fe94194..132ec5a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,11 @@ use bincode::{self, Error, options, Options}; -use std::net::{UdpSocket, SocketAddr}; +use std::{net::{UdpSocket, SocketAddr}, io::Write}; use serde::{Serialize, Deserialize}; +use sha3::{Digest, Sha3_512}; +use sha2::Sha512; + +mod big_array; +use big_array::BigArray; const MAX_FRAME_PAYLOAD:u16=508; const MAX_FRAME_PAYLOAD_U:usize=MAX_FRAME_PAYLOAD as usize; @@ -15,10 +20,12 @@ struct PacketInformation { response_filename_checksum: u64, //8 bytes } //14 bytes -struct packet { +#[derive(Serialize, Deserialize)] +struct Packet { packet_number: u32, //4 bytes - payload_hash: u128, //16 bytes - payload: [u8; MAX_PAYLOAD_U], //508 bytes + payload_hash: [u8; 16], //16 bytes + #[serde(with = "BigArray")] + payload: [u8; MAX_PAYLOAD_U], //488 bytes } //512 bytes fn main() { @@ -28,8 +35,8 @@ fn main() { let filename = "data.bin"; - let local_addr: SocketAddr = "0.0.0.0:26000".parse().expect("Failed to parse address"); - let socket = UdpSocket::bind(local_addr).expect("Failed to bind socket"); + let local_addr: SocketAddr = "0.0.0.0:0".parse().expect("Failed to parse address"); + let socket = UdpSocket::bind(local_addr).expect("Failed to bind socket to port 26000"); socket.set_read_timeout(Some(std::time::Duration::new(timeout, 0))).expect("set_read_timeout call failed"); //socket.set_nonblocking(true).expect("set_nonblocking call failed"); @@ -74,28 +81,137 @@ fn main() { } //create vector to store the packets - let mut packets: Vec = Vec::new(); + let mut packets: Vec> = Vec::new(); for _ in 0..packet_info.packet_numbers { - packets.push(0); + packets.push(Vec::new()); } - let received_packets = 0; + let mut received_packets = 0; + + let mut server_hash = [0u8; 64]; + let mut server_hash_received = false; //receive the packets - while received_packets < packet_info.packet_numbers { - let mut buffer = [0u8; MAX_PAYLOAD_U]; - let (received_bytes, remote_addr) = socket.recv_from(&mut buffer).expect("Failed to receive data"); + while received_packets < packet_info.packet_numbers-1 || !server_hash_received { + let mut buffer = [0u8; MAX_FRAME_PAYLOAD_U]; + let res = socket.recv_from(&mut buffer); + if let Ok((received_bytes, remote_addr)) = res { + let filled_buffer = &buffer;//[..received_bytes]; - let filled_buffer = &buffer[..received_bytes]; + //print the filled buffer + //println!("Received data: {:?} {}", filled_buffer, filled_buffer.len()); - //print the filled buffer - println!("Received data: {:?} {}", filled_buffer, filled_buffer.len()); + if remote_addr != server_addr { + panic!("Received data from unknown address"); + } - if remote_addr != server_addr { - panic!("Received data from unknown address"); + if received_bytes == 65 { + //println!("Received hash packet, ignoring for now"); + server_hash = filled_buffer[0..64].try_into().expect("Failed to convert hash"); + //println!("Received hash: {:?}", server_hash); + server_hash_received = true; + continue; + } + + if received_bytes != MAX_FRAME_PAYLOAD_U { + println!("Received packet with invalid size {} not {} | ignoring", received_bytes, MAX_FRAME_PAYLOAD_U); + continue; + } + + let packet_result: Result = options.deserialize(filled_buffer); + let packet: Packet; + match packet_result { + Ok(p) => { + //println!("Packet {}", p.packet_number); + //check checksum with sum + if p.packet_number != packet_info.packet_numbers - 1 { + let mut hasher = Sha3_512::new(); + hasher.update(&p.payload); + let hash = hasher.finalize(); + + if p.payload_hash != hash[..16] { + continue; + } + + packet = p; + } else { + //cut packet down to size + packets[p.packet_number as usize] = p.payload[..packet_info.last_packet_size as usize].to_vec(); + continue; + } + } + Err(err) => { + panic!("Failed to deserialize data: {}", err); + } + } + + if packets[packet.packet_number as usize].len() != 0 { + //println!("Packet already received, ignoring"); + continue; + } + + packets[packet.packet_number as usize] = packet.payload.to_vec(); + received_packets += 1; + } else { + //println!("Timeout, requesting again {}/{}\r", received_packets, packet_info.packet_numbers); + //collect packets that were not received and send them in n messages + //where n is the minimum amount of messages needed to request all packets + let mut missing_packets: Vec = Vec::new(); + for i in 0..packet_info.packet_numbers { + if packets[i as usize].len() == 0 { + missing_packets.push(i); + } + } + + //split lost_packets into groups of size 508-filename.len() bytes + let mut missing_packet_groups: Vec = Vec::new(); + let mut current_group: String = filename.to_string(); + for i in 0..missing_packets.len() { + if current_group.len() + missing_packets[i].to_string().len() + 1 > MAX_PAYLOAD_U { + missing_packet_groups.push(current_group); + current_group = filename.to_string(); + } + current_group.push('/'); + current_group.push_str(&missing_packets[i].to_string()); + } + + if current_group.len() > filename.len() { + missing_packet_groups.push(current_group); + } + + for i in 0..missing_packet_groups.len() { + let message = &missing_packet_groups[i]; + //println!("Requesting packets: {}", message); + socket.send_to(message.as_bytes(), server_addr).expect("Failed to send data"); + } + + if !server_hash_received { + let message = filename.to_string()+":"; + socket.send_to(message.as_bytes(), server_addr).expect("Failed to send data"); + } } - + print!("Packet {}/{}\r", received_packets, packet_info.packet_numbers); + } + + //check hash via sha512 + + let mut hasher = Sha512::new(); + for i in 0..packets.len() { + hasher.update(&packets[i]); + } + let client_hash = hasher.finalize(); + + if client_hash[..64] != server_hash { + panic!("Hashes do not match, correct hash: {:?}, client hash: {:?}", server_hash, client_hash); + } + + + println!("Received all packets, writing to file"); + //write packets to file + let mut file = std::fs::File::create("received/".to_string()+filename).expect("Failed to create file"); + for i in 0..packets.len() { + file.write_all(&packets[i]).expect("Failed to write to file"); } }