358 lines
12 KiB
Python
358 lines
12 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import random
|
|
import pickle
|
|
import logging
|
|
import math
|
|
import bisect
|
|
from enum import Enum
|
|
|
|
from msg import StringNetMsg
|
|
from server import Server
|
|
|
|
|
|
# Network parameters

# Mean size of a consensus diff, expressed as a fraction of the size of
# a full consensus.
P_Delta = 0.019
|
|
|
|
class WOMode(Enum):
    """Selects which Walking Onions variant (if any) the network runs."""

    VANILLA = 0      # Walking Onions disabled
    TELESCOPING = 1  # Telescoping Walking Onions
    SINGLEPASS = 2   # Single-Pass Walking Onions

    def string_to_type(self, type_input):
        """Translate a mode name into the matching WOMode member.

        Recognized names are 'vanilla', 'telescoping' and 'single-pass'.
        Any other input yields -1 rather than raising.  The self
        argument is not used."""
        lookup = {
            'vanilla': WOMode.VANILLA,
            'telescoping': WOMode.TELESCOPING,
            'single-pass': WOMode.SINGLEPASS,
        }
        return lookup.get(type_input, -1)
|
|
|
|
|
|
class SNIPAuthMode(Enum):
    """The different styles of SNIP authentication"""

    NONE = 0       # No SNIPs; only used for WOMode = VANILLA
    MERKLE = 1     # Merkle trees
    THRESHSIG = 2  # Threshold signatures

    # Only the non-NONE modes are parsed from command-line input:
    # VANILLA always takes NONE, and no other mode accepts NONE.  So
    # the parser only distinguishes between merkle and threshold
    # signatures.
    def string_to_type(self, type_input):
        """Translate a SNIP authentication mode name into its member.

        'merkle' maps to MERKLE.  'threshsig' and 'telescoping' map to
        THRESHSIG.  Any other input (including 'none') yields -1.  The
        self argument is not used."""
        reprs = {'merkle': SNIPAuthMode.MERKLE,
                 'threshsig': SNIPAuthMode.THRESHSIG,
                 'telescoping': SNIPAuthMode.THRESHSIG,
                 # Misspelling accepted by earlier versions of this
                 # lookup; kept so existing invocations keep working.
                 'telesocping': SNIPAuthMode.THRESHSIG}
        return reprs.get(type_input, -1)
|
|
|
|
|
|
class EntType(Enum):
    """Kinds of entity that participate in the simulated network."""

    NONE = 0
    DIRAUTH = 1
    RELAY = 2
    CLIENT = 3
|
|
|
|
|
|
class PerfStats:
    """Per-epoch performance counters for a single relay or client.

    Records bytes sent and received, counts of public-key operations
    (key generation, signing, verification, Diffie-Hellman), and
    circuit-building time.  The counters are meant to be cleared with
    reset() every epoch."""

    def __init__(self, ent_type, bw=None):
        # Kind of entity these stats belong to (DIRAUTH, RELAY, CLIENT)
        self.ent_type = ent_type
        # Printable name of the entity; None until someone sets it
        self.name = None
        # Relay bandwidth, when the entity has one
        self.bw = bw
        # True while the entity is bootstrapping during this epoch
        self.is_bootstrapping = False
        self.circuit_building_time = 0
        self.reset()

    def __str__(self):
        values = {
            'name': self.name,
            'type': self.ent_type.name,
            'boot': self.is_bootstrapping,
            'sent': self.bytes_sent,
            'recv': self.bytes_received,
            'keygen': self.keygens,
            'sig': self.sigs,
            'verif': self.verifs,
            'dh': self.dhs,
            'cbt': self.circuit_building_time,
        }
        template = ("%(name)s: type=%(type)s boot=%(boot)s sent=%(sent)s "
                    "recv=%(recv)s keygen=%(keygen)d sig=%(sig)d "
                    "verif=%(verif)d dh=%(dh)d "
                    "circuit_building_time=%(cbt)d")
        return template % values

    def reset(self):
        """Zero all the counters; usually done as each epoch begins."""
        # Traffic volume, in bytes
        self.bytes_sent = self.bytes_received = 0
        # Public-key operation counts: key generation, signing,
        # verification, Diffie-Hellman
        self.keygens = self.sigs = self.verifs = self.dhs = 0
        # Time spent building circuits
        self.circuit_building_time = 0
|
|
|
|
|
|
class PerfStatsStats:
    """Accumulate a number of PerfStats objects to compute the means and
    stddevs of their fields."""

    class SingleStat:
        """Accumulate single numbers to compute their mean and sample
        standard deviation."""

        def __init__(self):
            # Running sum, running sum of squares, and number of samples
            self.tot = 0
            self.totsq = 0
            self.N = 0

        def accum(self, x):
            """Add the sample x to the running totals."""
            self.tot += x
            self.totsq += x*x
            self.N += 1

        def __str__(self):
            """Format as "MEAN \\pm STDDEV" (sample standard deviation),
            as just "MEAN" when there is a single sample, or as "N=0"
            when nothing has been accumulated."""
            if self.N == 0:
                # Nothing to average; avoid dividing by zero
                return "N=0"
            mean = self.tot/self.N
            if self.N > 1:
                variance = (self.totsq - self.tot*self.tot/self.N) \
                    / (self.N - 1)
                # With (nearly) identical samples, floating-point
                # cancellation can push this very slightly below zero,
                # which would make math.sqrt raise a domain error.
                if variance < 0:
                    variance = 0
                stddev = math.sqrt(variance)
                # The backslash is escaped so the output still contains a
                # literal \pm without relying on an invalid escape
                # sequence (a SyntaxWarning on Python 3.12+).
                return "%f \\pm %f" % (mean, stddev)
            return "%f" % mean

    def __init__(self, usebw=False):
        # When usebw is True, also track byte counts divided by each
        # PerfStats' bw; accum() then requires a nonzero numeric bw.
        self.usebw = usebw
        self.bytes_sent = PerfStatsStats.SingleStat()
        self.bytes_received = PerfStatsStats.SingleStat()
        self.bytes_tot = PerfStatsStats.SingleStat()
        if self.usebw:
            self.bytesperbw_sent = PerfStatsStats.SingleStat()
            self.bytesperbw_received = PerfStatsStats.SingleStat()
            self.bytesperbw_tot = PerfStatsStats.SingleStat()
        self.keygens = PerfStatsStats.SingleStat()
        self.sigs = PerfStatsStats.SingleStat()
        self.verifs = PerfStatsStats.SingleStat()
        self.dhs = PerfStatsStats.SingleStat()
        self.circuit_building_time = PerfStatsStats.SingleStat()
        # Number of PerfStats objects accumulated so far
        self.N = 0

    def accum(self, stat):
        """Fold one PerfStats object's counters into the running
        statistics."""
        bytes_tot = stat.bytes_sent + stat.bytes_received
        self.bytes_sent.accum(stat.bytes_sent)
        self.bytes_received.accum(stat.bytes_received)
        self.bytes_tot.accum(bytes_tot)
        if self.usebw:
            self.bytesperbw_sent.accum(stat.bytes_sent/stat.bw)
            self.bytesperbw_received.accum(stat.bytes_received/stat.bw)
            self.bytesperbw_tot.accum(bytes_tot/stat.bw)
        self.keygens.accum(stat.keygens)
        self.sigs.accum(stat.sigs)
        self.verifs.accum(stat.verifs)
        self.dhs.accum(stat.dhs)
        self.circuit_building_time.accum(stat.circuit_building_time)
        self.N += 1

    def __str__(self):
        """Return space-separated "label=value" pairs for every tracked
        statistic followed by the sample count, or "N=0" if nothing has
        been accumulated."""
        if self.N <= 0:
            return "N=0"
        fields = [("sent", self.bytes_sent),
                  ("recv", self.bytes_received),
                  ("bytes", self.bytes_tot)]
        if self.usebw:
            fields += [("sentperbw", self.bytesperbw_sent),
                       ("recvperbw", self.bytesperbw_received),
                       ("bytesperbw", self.bytesperbw_tot)]
        fields += [("keygen", self.keygens),
                   ("sig", self.sigs),
                   ("verif", self.verifs),
                   ("dh", self.dhs),
                   ("circuit_building_time", self.circuit_building_time),
                   ("N", self.N)]
        return " ".join("%s=%s" % (label, value) for label, value in fields)
|
|
|
|
|
|
class NetAddr:
    """An opaque simulated network address.

    Every new instance takes the next integer from a class-wide counter,
    so no two addresses created in one run share a value."""

    # The integer that the next NetAddr created will receive
    nextaddr = 1

    def __init__(self):
        """Allocate a fresh, previously unused address."""
        fresh = NetAddr.nextaddr
        self.addr = fresh
        NetAddr.nextaddr = fresh + 1

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        return self.__dict__ == other.__dict__

    def __hash__(self):
        # Consistent with __eq__: equal addresses share the same addr
        return hash(self.addr)

    def __str__(self):
        return str(self.addr)
|
|
|
|
|
|
class NetNoServer(Exception):
    """Raised when a connection is attempted to an address that has no
    server listening on it."""
|
|
|
|
|
|
class Network:
    """A class representing a simulated network.

    Servers can bind() to the network, yielding a NetAddr (network
    address), and clients can connect() to a NetAddr, yielding the
    connection object returned by the bound server.  The network also
    tracks the current epoch and dispatches epoch-change callbacks, and
    stores the dirauth verification keys, the fallback relays, and the
    Walking Onions / SNIP authentication configuration."""

    def __init__(self):
        # Maps each bound NetAddr to its server
        self.servers = dict()
        # The current epoch number
        self.epoch = 1
        # Objects registered through wantepochticks().  At each epoch
        # change the "ending" lists get epoch_ending() before the epoch
        # number is incremented, then the other lists get newepoch();
        # in each phase the priority list is run first.
        self.epochprioritycallbacks = []
        self.epochpriorityendingcallbacks = []
        self.epochcallbacks = []
        self.epochendingcallbacks = []
        # Dirauth public verification keys, indexed by dirauth number
        self.dirauthkeylist = []
        # Fallback relays and the bandwidth prefix sums over them (see
        # setfallbackrelays).  The prefix-sum fields are initialized here
        # so that getfallbackrelay() raises a clear ValueError, instead
        # of an AttributeError, if no fallback relays have been set.
        self.fallbackrelays = []
        self.fallbackbwcdf = []
        self.fallbacktotbw = 0
        # Walking Onions configuration; changed via set_wo_style()
        self.womode = WOMode.VANILLA
        self.snipauthmode = SNIPAuthMode.NONE

    def printservers(self):
        """Print the list of NetAddrs bound to something."""
        print("Servers:")
        for a in self.servers:
            print(a)

    def setdirauthkey(self, index, vk):
        """Set the public verification key for dirauth number index to
        vk, padding the key list with None if it is not yet long
        enough."""
        if index >= len(self.dirauthkeylist):
            self.dirauthkeylist.extend(
                [None] * (index+1-len(self.dirauthkeylist)))
        self.dirauthkeylist[index] = vk

    def dirauthkeys(self):
        """Return the list of dirauth public verification keys."""
        return self.dirauthkeylist

    def getepoch(self):
        """Return the current epoch."""
        return self.epoch

    def nextepoch(self):
        """End the current epoch, increment the epoch number, start the
        new epoch, and return the new epoch number.

        First every registered epoch_ending() callback is called with the
        old epoch number (priority callbacks first).  Then the epoch is
        incremented, and every registered newepoch() callback is called
        with the new number (priority callbacks first)."""
        logging.info("Ending epoch %s", self.epoch)
        self._run_epoch_callbacks(
            [self.epochpriorityendingcallbacks, self.epochendingcallbacks],
            'epoch_ending', "Ending epoch %s %d%% complete")
        self.epoch += 1
        logging.info("Starting epoch %s", self.epoch)
        self._run_epoch_callbacks(
            [self.epochprioritycallbacks, self.epochcallbacks],
            'newepoch', "Starting epoch %s %d%% complete")
        logging.info("Epoch %s started", self.epoch)
        return self.epoch

    def _run_epoch_callbacks(self, cblists, methodname, progressfmt):
        """Call methodname(self.epoch) on every object in each list of
        cblists, in order.  Log progressfmt (with the epoch and the
        integer percent complete) whenever that percentage changes."""
        # The total is computed once, before any callback runs
        total = sum(len(l) for l in cblists)
        numcalled = 0
        lastroundpercent = -1
        for l in cblists:
            for c in l:
                getattr(c, methodname)(self.epoch)
                numcalled += 1
                roundpercent = int(100*numcalled/total)
                if roundpercent != lastroundpercent:
                    logging.info(progressfmt, self.epoch, roundpercent)
                    lastroundpercent = roundpercent

    def wantepochticks(self, callback, want, priority=False, end=False):
        """Register or deregister an object from receiving epoch change
        callbacks.

        If want is True, the callback object's newepoch() method will be
        called at each epoch change, with an argument of the new epoch.
        If want is False, the callback object will be deregistered
        (list.remove raises ValueError if it is not in the selected
        list).  If priority is True, call back this object before any
        object with priority=False.  If end is True, the callback
        object's epoch_ending() method will be called instead at the end
        of the epoch, just _before_ the epoch number change."""
        if end:
            if priority:
                l = self.epochpriorityendingcallbacks
            else:
                l = self.epochendingcallbacks
        else:
            if priority:
                l = self.epochprioritycallbacks
            else:
                l = self.epochcallbacks

        if want:
            l.append(callback)
        else:
            l.remove(callback)

    def bind(self, server):
        """Bind a server to a newly generated NetAddr, returning the
        NetAddr.  The server's bind() method is called with the new
        address and a zero-argument callable that unbinds it."""
        addr = NetAddr()
        self.servers[addr] = server
        server.bind(addr, lambda: self.servers.pop(addr))
        return addr

    def connect(self, client, srvaddr, perfstats):
        """Connect the given client to the server bound to srvaddr,
        attach perfstats to the resulting connection, and return it.

        Raises NetNoServer if there is no server bound to srvaddr."""
        try:
            server = self.servers[srvaddr]
        except KeyError:
            raise NetNoServer() from None
        conn = server.connected(client)
        conn.perfstats = perfstats
        return conn

    def setfallbackrelays(self, fallbackrelays):
        """Set the list of globally known fallback relays. Clients use
        these to bootstrap when they know no other relays.

        Each relay's bw attribute is used to build the cumulative
        bandwidth table that getfallbackrelay() samples from."""
        self.fallbackrelays = fallbackrelays

        # Construct the CDF of fallback relay bws, so that clients can
        # choose a fallback relay weighted by bw.  fallbackbwcdf[i] is
        # the total bw of fallbackrelays[0:i].
        self.fallbackbwcdf = [0]
        for r in fallbackrelays:
            self.fallbackbwcdf.append(self.fallbackbwcdf[-1]+r.bw)
        # The last item is the sum of all the relays' bw; keep it as the
        # total rather than as a CDF entry.
        self.fallbacktotbw = self.fallbackbwcdf.pop()

    def getfallbackrelay(self):
        """Get a random one of the globally known fallback relays,
        weighted by bw. Clients use these to bootstrap when they know
        no other relays.

        Raises ValueError if no fallback relays with positive total bw
        have been set."""
        if self.fallbacktotbw <= 0:
            raise ValueError("No fallback relays with positive bw are set")
        idx = random.randint(0, self.fallbacktotbw-1)
        # Pick the last relay whose cumulative-bw start is <= idx; a
        # relay with zero bw has the same start as its successor and so
        # is never selected.
        i = bisect.bisect_right(self.fallbackbwcdf, idx)
        return self.fallbackrelays[i-1]

    def set_wo_style(self, womode, snipauthmode):
        """Set the Walking Onions mode and the SNIP authenticate mode
        for the network.

        Raises ValueError unless VANILLA is paired with
        SNIPAuthMode.NONE and every other mode with a mode other than
        NONE."""
        if (womode == WOMode.VANILLA) != (snipauthmode == SNIPAuthMode.NONE):
            # Incompatible settings
            raise ValueError("Bad argument combination")

        self.womode = womode
        self.snipauthmode = snipauthmode
|
|
|
|
|
|
# Module-level Network instance; every importer of this module shares
# this one object.
thenetwork = Network()
|