1
0
walkingonions-boosted/network.py
2022-03-17 17:05:34 +01:00

358 lines
12 KiB
Python

#!/usr/bin/env python3
import random
import pickle
import logging
import math
import bisect
from enum import Enum
from msg import StringNetMsg
from server import Server
# Network parameters
# On average, how large is a consensus diff as compared to a full
# consensus?  Expressed as a ratio of the diff size to the full
# consensus size (0.019, i.e. about 1.9%).
P_Delta = 0.019
class WOMode(Enum):
    """The Walking Onions variant that the simulated network runs."""
    VANILLA = 0      # No Walking Onions
    TELESCOPING = 1  # Telescoping Walking Onions
    SINGLEPASS = 2   # Single-Pass Walking Onions

    def string_to_type(self, type_input):
        """Translate a command-line mode name into a WOMode member.

        Recognised names are 'vanilla', 'telescoping' and 'single-pass'.
        The self argument is ignored.  Returns the sentinel -1 (rather
        than raising) when type_input is not a recognised name."""
        lookup = {
            'vanilla': WOMode.VANILLA,
            'telescoping': WOMode.TELESCOPING,
            'single-pass': WOMode.SINGLEPASS,
        }
        return lookup.get(type_input, -1)
class SNIPAuthMode(Enum):
    """The different styles of SNIP authentication"""
    NONE = 0       # No SNIPs; only used for WOMode = VANILLA
    MERKLE = 1     # Merkle trees
    THRESHSIG = 2  # Threshold signatures

    def string_to_type(self, type_input):
        """Translate a command-line SNIP authentication name into a
        SNIPAuthMode member.

        Only the non-NONE modes are parsed: VANILLA always pairs with
        NONE and no other WOMode may use it (Network.set_wo_style
        enforces this), so 'none' is deliberately not accepted.

        Accepted names are 'merkle' (MERKLE) and 'threshsig'
        (THRESHSIG).  For backward compatibility, 'telescoping', and
        its earlier misspelling 'telesocping', are also accepted as
        aliases for THRESHSIG.

        The self argument is ignored.  Returns the sentinel -1 (rather
        than raising) when type_input is not a recognised name."""
        reprs = {
            'merkle': SNIPAuthMode.MERKLE,
            'threshsig': SNIPAuthMode.THRESHSIG,
            # Legacy aliases: the old table only had the misspelled
            # 'telesocping' key, so 'threshsig' was never accepted.
            'telescoping': SNIPAuthMode.THRESHSIG,
            'telesocping': SNIPAuthMode.THRESHSIG,
        }
        return reprs.get(type_input, -1)
class EntType(Enum):
    """The different types of entities in the system.

    Used as the ent_type of a PerfStats object, whose __str__ prints the
    member's name."""
    NONE = 0     # no particular entity type
    DIRAUTH = 1  # a dirauth
    RELAY = 2    # a relay
    CLIENT = 3   # a client
class PerfStats:
    """A class to store performance statistics for a relay or client.
    We keep track of bytes sent, bytes received, counts of
    public-key operations of various types, and a circuit-building time
    total.  We will reset these every epoch via reset()."""

    def __init__(self, ent_type, bw=None):
        # Which type of entity is this for (DIRAUTH, RELAY, CLIENT);
        # __str__ prints ent_type.name
        self.ent_type = ent_type
        # A printable name for the entity
        self.name = None
        # The relay bandwidth, if appropriate
        self.bw = bw
        # True if bootstrapping this epoch (not cleared by reset())
        self.is_bootstrapping = False
        # reset() creates and zeroes every counter, including
        # circuit_building_time, so it is not initialized separately here
        self.reset()

    def __str__(self):
        # circuit_building_time uses %s rather than %d so that a
        # non-integer value is printed as-is instead of being silently
        # truncated; integer values print exactly as before.
        return ("%s: type=%s boot=%s sent=%s recv=%s keygen=%d sig=%d "
                "verif=%d dh=%d circuit_building_time=%s") % \
            (self.name, self.ent_type.name, self.is_bootstrapping,
             self.bytes_sent, self.bytes_received, self.keygens,
             self.sigs, self.verifs, self.dhs, self.circuit_building_time)

    def reset(self):
        """Reset the counters, typically at the beginning of each
        epoch."""
        # Bytes sent and received
        self.bytes_sent = 0
        self.bytes_received = 0
        # Public-key operations: key generation, signing, verification,
        # Diffie-Hellman
        self.keygens = 0
        self.sigs = 0
        self.verifs = 0
        self.dhs = 0
        # Circuit building
        self.circuit_building_time = 0
class PerfStatsStats:
    """Accumulate a number of PerfStats objects to compute the means and
    stddevs of their fields."""

    class SingleStat:
        """Accumulate single numbers to compute their mean and
        (sample) stddev."""

        def __init__(self):
            # Running sum, running sum of squares, and count of the
            # values passed to accum()
            self.tot = 0
            self.totsq = 0
            self.N = 0

        def accum(self, x):
            """Add one value x to the running totals."""
            self.tot += x
            self.totsq += x*x
            self.N += 1

        def __str__(self):
            """Format the mean, followed (when more than one value has
            been accumulated) by the LaTeX plus-minus macro and the
            sample standard deviation (N-1 denominator).  Raises
            ZeroDivisionError if nothing has been accumulated."""
            mean = self.tot/self.N
            if self.N > 1:
                variance = (self.totsq - self.tot*self.tot/self.N) \
                    / (self.N - 1)
                # Floating-point cancellation in totsq - tot^2/N can make
                # the variance of (nearly) identical values come out as a
                # tiny negative number, on which math.sqrt raises
                # ValueError.  The true variance is never negative.
                if variance < 0:
                    variance = 0.0
                stddev = math.sqrt(variance)
                # Raw string: "\p" is not a valid escape sequence, so the
                # plain literal triggered a Deprecation/SyntaxWarning.
                # The produced text is unchanged.
                return r"%f \pm %f" % (mean, stddev)
            else:
                return "%f" % mean

    def __init__(self, usebw=False):
        # If usebw, also track bytes per unit of bw; every accumulated
        # stat must then have a nonzero bw, since accum() divides by it
        self.usebw = usebw
        self.bytes_sent = PerfStatsStats.SingleStat()
        self.bytes_received = PerfStatsStats.SingleStat()
        self.bytes_tot = PerfStatsStats.SingleStat()
        if self.usebw:
            self.bytesperbw_sent = PerfStatsStats.SingleStat()
            self.bytesperbw_received = PerfStatsStats.SingleStat()
            self.bytesperbw_tot = PerfStatsStats.SingleStat()
        self.keygens = PerfStatsStats.SingleStat()
        self.sigs = PerfStatsStats.SingleStat()
        self.verifs = PerfStatsStats.SingleStat()
        self.dhs = PerfStatsStats.SingleStat()
        self.circuit_building_time = PerfStatsStats.SingleStat()
        # Number of stats accumulated so far
        self.N = 0

    def accum(self, stat):
        """Fold the counters of one PerfStats-like object into the
        running statistics."""
        self.bytes_sent.accum(stat.bytes_sent)
        self.bytes_received.accum(stat.bytes_received)
        self.bytes_tot.accum(stat.bytes_sent + stat.bytes_received)
        if self.usebw:
            self.bytesperbw_sent.accum(stat.bytes_sent/stat.bw)
            self.bytesperbw_received.accum(stat.bytes_received/stat.bw)
            self.bytesperbw_tot.accum((stat.bytes_sent + stat.bytes_received)/stat.bw)
        self.keygens.accum(stat.keygens)
        self.sigs.accum(stat.sigs)
        self.verifs.accum(stat.verifs)
        self.dhs.accum(stat.dhs)
        self.circuit_building_time.accum(stat.circuit_building_time)
        self.N += 1

    def __str__(self):
        # Only format the per-field stats when at least one stat was
        # accumulated; SingleStat.__str__ would divide by zero otherwise
        if self.N > 0:
            if self.usebw:
                return "sent=%s recv=%s bytes=%s sentperbw=%s recvperbw=%s bytesperbw=%s keygen=%s sig=%s verif=%s dh=%s circuit_building_time=%s N=%s" % \
                    (self.bytes_sent, self.bytes_received, self.bytes_tot,
                     self.bytesperbw_sent, self.bytesperbw_received,
                     self.bytesperbw_tot,
                     self.keygens, self.sigs, self.verifs, self.dhs, self.circuit_building_time, self.N)
            else:
                return "sent=%s recv=%s bytes=%s keygen=%s sig=%s verif=%s dh=%s circuit_building_time=%s N=%s" % \
                    (self.bytes_sent, self.bytes_received, self.bytes_tot,
                     self.keygens, self.sigs, self.verifs, self.dhs, self.circuit_building_time, self.N)
        else:
            return "N=0"
class NetAddr:
    """A simulated network address, wrapping a unique integer."""

    # Shared counter: the integer that the next NetAddr will receive.
    nextaddr = 1

    def __init__(self):
        """Claim the next unused address number."""
        self.addr = NetAddr.nextaddr
        NetAddr.nextaddr = self.addr + 1

    def __eq__(self, other):
        # Equal only to instances of this same class whose attributes
        # all match
        if not isinstance(other, self.__class__):
            return False
        return vars(self) == vars(other)

    def __hash__(self):
        return hash(self.addr)

    def __str__(self):
        return str(self.addr)
class NetNoServer(Exception):
    """No server is listening on the address someone tried to connect
    to.  Raised by Network.connect() when no server is bound to the
    requested NetAddr."""
class Network:
    """A class representing a simulated network. Servers can bind()
    to the network, yielding a NetAddr (network address), and clients
    can connect() to a NetAddr yielding a Connection."""

    def __init__(self):
        # Map from NetAddr to the server bound at that address
        self.servers = dict()
        # The current epoch number
        self.epoch = 1
        # Objects registered via wantepochticks(); each priority list is
        # called back before its non-priority counterpart
        self.epochprioritycallbacks = []
        self.epochpriorityendingcallbacks = []
        self.epochcallbacks = []
        self.epochendingcallbacks = []
        # dirauth public verification keys, indexed by dirauth number
        self.dirauthkeylist = []
        # Fallback relays and the bw CDF used to pick among them.  These
        # are filled in by setfallbackrelays(); they are initialized here
        # so getfallbackrelay() can report a clear error if it is called
        # first.
        self.fallbackrelays = []
        self.fallbackbwcdf = []
        self.fallbacktotbw = 0
        self.womode = WOMode.VANILLA
        self.snipauthmode = SNIPAuthMode.NONE

    def printservers(self):
        """Print the list of NetAddrs bound to something."""
        print("Servers:")
        for a in self.servers:
            print(a)

    def setdirauthkey(self, index, vk):
        """Set the public verification key for dirauth number index to
        vk, growing the key list (padding with None) if needed."""
        if index >= len(self.dirauthkeylist):
            self.dirauthkeylist.extend([None] * (index+1-len(self.dirauthkeylist)))
        self.dirauthkeylist[index] = vk

    def dirauthkeys(self):
        """Return the list of dirauth public verification keys."""
        return self.dirauthkeylist

    def getepoch(self):
        """Return the current epoch."""
        return self.epoch

    def _run_epoch_callbacks(self, callbacklists, invoke, progressmsg):
        """Call invoke(c) on every object c in callbacklists, in order,
        logging progressmsg (formatted with the current epoch and the
        integer percent complete) whenever that percentage changes."""
        # The total is taken once, before any callback runs
        total = sum(len(l) for l in callbacklists)
        numcalled = 0
        lastroundpercent = -1
        for l in callbacklists:
            for c in l:
                invoke(c)
                numcalled += 1
                roundpercent = int(100*numcalled/total)
                if roundpercent != lastroundpercent:
                    logging.info(progressmsg, self.epoch, roundpercent)
                    lastroundpercent = roundpercent

    def nextepoch(self):
        """Increment the current epoch, and return it.

        First each registered epoch_ending() callback (priority ones
        first) is called with the old epoch number.  Then the epoch is
        incremented, and each registered newepoch() callback (priority
        ones first) is called with the new epoch number."""
        logging.info("Ending epoch %s", self.epoch)
        self._run_epoch_callbacks(
            [self.epochpriorityendingcallbacks, self.epochendingcallbacks],
            lambda c: c.epoch_ending(self.epoch),
            "Ending epoch %s %d%% complete")
        self.epoch += 1
        logging.info("Starting epoch %s", self.epoch)
        self._run_epoch_callbacks(
            [self.epochprioritycallbacks, self.epochcallbacks],
            lambda c: c.newepoch(self.epoch),
            "Starting epoch %s %d%% complete")
        logging.info("Epoch %s started", self.epoch)
        return self.epoch

    def wantepochticks(self, callback, want, priority=False, end=False):
        """Register or deregister an object from receiving epoch change
        callbacks. If want is True, the callback object's newepoch()
        method will be called at each epoch change, with an argument of
        the new epoch. If want is False, the callback object will be
        deregistered (ValueError if it was not registered). If priority
        is True, call back this object before any object with
        priority=False. If end is True, the callback object's
        epoch_ending() method will be called instead at the end of the
        epoch, just _before_ the epoch number change."""
        if end:
            if priority:
                l = self.epochpriorityendingcallbacks
            else:
                l = self.epochendingcallbacks
        else:
            if priority:
                l = self.epochprioritycallbacks
            else:
                l = self.epochcallbacks
        if want:
            l.append(callback)
        else:
            l.remove(callback)

    def bind(self, server):
        """Bind a server to a newly generated NetAddr, returning the
        NetAddr. The server's bind() method is called with the new
        address and a callback that, when invoked, unbinds the server
        (removes it from this network)."""
        addr = NetAddr()
        self.servers[addr] = server
        server.bind(addr, lambda: self.servers.pop(addr))
        return addr

    def connect(self, client, srvaddr, perfstats):
        """Connect the given client to the server bound to srvaddr,
        returning the connection produced by the server's connected()
        method with its perfstats attribute set to perfstats. Raise
        NetNoServer if there is no server bound to that address."""
        try:
            server = self.servers[srvaddr]
        except KeyError:
            raise NetNoServer()
        conn = server.connected(client)
        conn.perfstats = perfstats
        return conn

    def setfallbackrelays(self, fallbackrelays):
        """Set the list of globally known fallback relays. Clients use
        these to bootstrap when they know no other relays."""
        self.fallbackrelays = fallbackrelays
        # Construct the CDF of fallback relay bws, so that clients can
        # choose a fallback relay weighted by bw: fallbackbwcdf[i] is the
        # total bw of relays 0..i-1, so relay i owns the integer range
        # [fallbackbwcdf[i], fallbackbwcdf[i] + bw_i).
        self.fallbackbwcdf = [0]
        for r in fallbackrelays:
            self.fallbackbwcdf.append(self.fallbackbwcdf[-1]+r.bw)
        # Remove the last item, which should be the sum of all the
        # relays
        self.fallbacktotbw = self.fallbackbwcdf.pop()

    def getfallbackrelay(self):
        """Get a random one of the globally known fallback relays,
        weighted by bw. Clients use these to bootstrap when they know
        no other relays. Raise ValueError if no fallback relay with
        positive bw has been set."""
        if self.fallbacktotbw <= 0:
            raise ValueError("No fallback relays with positive bw have been set")
        idx = random.randint(0, self.fallbacktotbw-1)
        # bisect_right skips over zero-bw relays, whose CDF entries
        # repeat the previous value
        i = bisect.bisect_right(self.fallbackbwcdf, idx)
        return self.fallbackrelays[i-1]

    def set_wo_style(self, womode, snipauthmode):
        """Set the Walking Onions mode and the SNIP authenticate mode
        for the network. Raise ValueError if either is not a WOMode /
        SNIPAuthMode member, or if the combination is incompatible
        (VANILLA must be paired with NONE, and only VANILLA may use
        NONE)."""
        if not isinstance(womode, WOMode) or \
                not isinstance(snipauthmode, SNIPAuthMode):
            # e.g. the -1 sentinel string_to_type() returns for an
            # unrecognised name, which would otherwise slip past the
            # combination check below
            raise ValueError("Unknown Walking Onions or SNIP authentication mode")
        if (womode == WOMode.VANILLA) != (snipauthmode == SNIPAuthMode.NONE):
            # Incompatible settings
            raise ValueError("Bad argument combination")
        self.womode = womode
        self.snipauthmode = snipauthmode
# The singleton instance of Network, created at import time.
# NOTE(review): presumably other modules share this object rather than
# constructing their own Network — confirm against callers.
thenetwork = Network()