ITransfer/client.py
2023-05-27 22:15:51 +02:00

170 lines
4.9 KiB
Python

# Seconds to wait in select() for a datagram before treating it as a timeout.
timeout = 0.5
# Set to True to talk to a server on this machine instead of the remote one.
local_server = False
serverAddressPort = ("213.47.107.152", 1337)
if local_server:
    # Tuples are immutable, so item assignment would raise TypeError;
    # build a new (host, port) pair that keeps the configured port.
    serverAddressPort = ("127.0.0.1", serverAddressPort[1])
possible_value = input("Enter a value: ")
# An empty answer falls back to the default file name.
filename = possible_value or "data.bin"
import hashlib
import math
import os
import select
import socket
import struct
import time
def convert_to_highest_speed(bps):
    """Format a rate given in bps using the largest fitting decimal unit.

    Values below 1000 are shown as plain ``bps``; larger values are scaled
    by powers of 1000 to ``Kbps``, ``Mbps`` or ``Gbps``. Anything at or
    above 1000**4 is reported in ``Tbps`` so the function always returns a
    string instead of silently returning ``None``.

    Args:
        bps: The rate to format (int or float).

    Returns:
        The rate rounded to two decimals followed by its unit, e.g.
        ``'1.5Kbps'``.
    """
    if bps < 1000:
        # No division here so integer inputs keep their integer formatting.
        return f'{round(bps, 2)}bps'
    for exponent, unit in enumerate(('Kbps', 'Mbps', 'Gbps'), start=1):
        if bps < 1000 ** (exponent + 1):
            return f'{round(bps / 1000 ** exponent, 2)}{unit}'
    # Catch-all for everything beyond the Gbps range.
    return f'{round(bps / 1000 ** 4, 2)}Tbps'
# Layout of one data frame: a header in network byte order made of an
# unsigned 32-bit value followed by a 16-byte field, then the payload.
header = "!L16s"
# Total size in bytes of every data datagram exchanged with the server.
# 508 bytes is the conservative UDP payload size; this client uses more.
max_frame_payload = 1016
header_size = struct.Struct(header).size
# Bytes of file data that fit in one frame after the header.
max_payload = max_frame_payload - header_size
# Handshake: send the requested file name and wait for the server's reply,
# which is unpacked as (packet count, size of the last packet, checksum of
# the file name). The checksum is the sum of the UTF-8 bytes of the name.
UDPClientSocket = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
message = filename
UDPClientSocket.sendto(message.encode(), serverAddressPort)
ready = select.select([UDPClientSocket], [], [], timeout)
if not ready[0]:
    print("Server not responding")
    # Non-zero status so callers can tell the transfer failed.
    raise SystemExit(1)
msgFromServer = UDPClientSocket.recvfrom(max_frame_payload)[0]
try:
    msg = struct.unpack('!IHQ', msgFromServer)
except struct.error:
    # A reply of the wrong length cannot be a valid handshake answer.
    print('Server error: malformed response')
    raise SystemExit(1) from None
last_packet_size = msg[1]
if msg[2] != sum(message.encode('utf-8')):
    print('Server error: wrong file name')
    raise SystemExit(1)
total_packets = msg[0]
# Countdown of packets still missing; total_packets stays constant.
packets = msg[0]
print(f'Packets to receive: {packets}')
# Reassembly state for the transfer.
total_data = b''
# One slot per expected packet, indexed by sequence number; None = missing.
data_packets = [None] * packets
sha512 = hashlib.sha512()
# Digest announced by the server; stays None until it has been received.
serverhash = None
# Reference point for the download-speed estimate.
time_started = time.time()
# Receive data frames until every packet slot is filled and the server's
# digest has arrived. When select() times out, re-request whatever is still
# missing (lost packets and/or the digest).
# NOTE(review): this copy of the file lost its indentation; the nesting
# below was reconstructed from the control flow - confirm against the
# original before merging.
while packets > 0 or serverhash is None:
    ready = select.select([UDPClientSocket], [], [], timeout)
    if ready[0]:
        datareceived = UDPClientSocket.recvfrom(max_frame_payload)[0]
        if len(datareceived) == 65:
            # 65-byte datagram: 64 bytes of digest plus one pad byte.
            serverhash = struct.unpack("64sx", datareceived)[0]
            continue
        if len(datareceived) != max_frame_payload:
            print("\n[WARNING] bad packet: wrong size", len(datareceived))
            continue
        seq, checksum, data = struct.unpack(f"{header}{max_payload}s", datareceived)
        if seq >= total_packets:
            # The sequence number comes from the network; never index the
            # packet list with a value outside the announced range.
            print("\n[WARNING] bad packet: sequence number out of range", seq)
            continue
        if data_packets[seq] is not None:
            print("\n[INFO] bad packet: received multiple times", seq)
            continue
        # The last packet is shorter than a full payload: cut it to size
        # before verifying it.
        if seq == total_packets - 1:
            data = data[:last_packet_size]
        # The frame checksum is the first 16 bytes of the SHA3-512 digest
        # of the payload.
        expected_checksum = hashlib.sha3_512(data).digest()[:16]
        if checksum != expected_checksum:
            print("\n[WARNING] bad packet: checksum mismatch", seq)
            print(checksum.hex(), expected_checksum.hex())
            print(data.hex())
            continue
        data_packets[seq] = data
        packets -= 1
    else:
        # Timeout: collect the sequence numbers that are still missing and
        # pack them into as few request datagrams as possible. Each request
        # is "<file name>/<seq>/<seq>/..." and is measured in encoded bytes
        # so non-ASCII file names cannot overflow a frame.
        lost_packets = [
            index for index, packet in enumerate(data_packets) if packet is None
        ]
        request_prefix = message.encode()
        current_packets = []
        request = request_prefix
        for index in lost_packets:
            entry = b"/%d" % index
            if len(request) + len(entry) > max_frame_payload:
                current_packets.append(request)
                request = request_prefix
            request += entry
        if lost_packets:
            # Back off to relieve the connection and the server.
            timeout *= 2
            current_packets.append(request)
        for request in current_packets:
            UDPClientSocket.sendto(request, serverAddressPort)
        if serverhash is None:
            # "<file name>:" asks the server to (re)send the digest.
            UDPClientSocket.sendto(str.encode(f"{message}:"), serverAddressPort)
    received_count = total_packets - packets
    elapsed = time.time() - time_started
    download_speed = received_count / (elapsed if elapsed != 0 else 1) * max_payload
    # An empty file has no packets; report it as complete instead of
    # dividing by zero.
    percent_done = round(received_count / total_packets * 100, 2) if total_packets else 100.0
    print(f'{received_count}/{total_packets} | {percent_done}% | {convert_to_highest_speed(download_speed)}' + '\t' * 5, end='\r')
# Reassemble the file, compare its SHA-512 digest with the one announced by
# the server and write the result below "received/". On a mismatch the data
# is kept under a ".corrupt" name and the script exits with status 1.
print("finished transfer" + "\t"*5, end='\r')
total_data = b''.join(data_packets)
sha512 = hashlib.sha512(total_data).hexdigest()
output_path = "received/" + filename
# Create the target directory up front so a finished download is not lost
# to a FileNotFoundError at the very last step.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
if sha512 != serverhash.hex():
    print('sha512: %s' % sha512)
    print('server sha512: %s' % serverhash.hex())
    print(f'hash mismatch, writing corrupt file to disk: {filename}.corrupt')
    with open(output_path + ".corrupt", 'wb') as f:
        f.write(total_data)
    # Non-zero status so callers can detect the failed verification.
    raise SystemExit(1)
with open(output_path, 'wb') as f:
    f.write(total_data)
print(f'{filename} generated with size {len(total_data)} bytes')