1
0
walkingonions-boosted/dirauth.py

517 lines
22 KiB
Python
Raw Normal View History

2022-03-17 16:05:34 +00:00
#!/usr/bin/env python3
import os
import random # For simulation, not cryptography!
import bisect
import math
import logging
import resource
import nacl.encoding
import nacl.signing
import merklelib
import hashlib
import network
import msg as simmsg
from connection import DirAuthConnection
# A relay descriptor (RelayDescriptor) wraps a dict, descdict, containing:
# epoch: epoch id
# idkey: a public identity key
# onionkey: a public onion key
# addr: a network address
# bw: bandwidth
# flags: relay flags
# pathselkey: a path selection public key (Single-Pass Walking Onions only)
# vrfkey: a VRF public key (Single-Pass Walking Onions only)
# sig: a signature over the above by the idkey
class RelayDescriptor:
    """A relay's self-published description for one epoch.

    All fields live in the dict self.descdict (the expected keys are
    listed in the comment above this class).  The text produced by
    __str__(False) is exactly what the relay's identity key signs, so
    its byte layout must not change.
    """

    def __init__(self, descdict):
        # Stored as-is (not copied): sign() later adds "sig" to this dict.
        self.descdict = descdict

    def __str__(self, withsig = True):
        """Serialize the descriptor, one " key: value" line per field.

        Fields are emitted in a fixed order.  idkey, onionkey and
        pathselkey are hex-encoded; "sig" is hex-encoded and emitted only
        when withsig is true; every other field goes through str().
        """
        def as_hex(raw):
            return nacl.encoding.HexEncoder.encode(raw).decode("ascii")

        hexkeys = ("idkey", "onionkey", "pathselkey")
        order = ("epoch", "idkey", "onionkey", "pathselkey", "addr",
                 "bw", "flags", "vrfkey", "sig")
        lines = ["RelayDesc ["]
        for field in order:
            if field not in self.descdict:
                continue
            value = self.descdict[field]
            if field in hexkeys:
                rendered = as_hex(value)
            elif field == "sig":
                if not withsig:
                    continue
                rendered = as_hex(value)
            else:
                rendered = str(value)
            lines.append(" " + field + ": " + rendered)
        lines.append("]")
        return "\n".join(lines) + "\n"

    def sign(self, signingkey, perfstats):
        """Sign the sig-less serialization with signingkey.

        The raw signature is stored in descdict["sig"], and one signature
        operation is counted in perfstats.sigs.
        """
        message = self.__str__(False).encode("ascii")
        signed = signingkey.sign(message)
        perfstats.sigs += 1
        self.descdict["sig"] = signed.signature

    @staticmethod
    def verify(desc, perfstats):
        """Check desc's "sig" against the descriptor's own "idkey".

        One verification is counted in perfstats.verifs before the check.
        Nothing is returned; any exception raised by
        nacl.signing.VerifyKey.verify propagates to the caller.
        """
        assert type(desc) is RelayDescriptor
        text = desc.__str__(False)
        perfstats.verifs += 1
        verifier = nacl.signing.VerifyKey(desc.descdict["idkey"])
        verifier.verify(text.encode("ascii"), desc.descdict["sig"])
# A SNIP wraps a dict, snipdict, containing:
# epoch: epoch id
# idkey: a public identity key
# onionkey: a public onion key
# addr: a network address
# flags: relay flags
# pathselkey: a path selection public key (Single-Pass Walking Onions only)
# range: the (lo,hi) values for the index range (lo is inclusive, hi is
# exclusive; that is, x is in the range if lo <= x < hi).
# lo=hi denotes an empty range.
# auth: either a signature from the authorities over the above
# (Threshold signature case) or a Merkle path to the root
# contained in the consensus (Merkle tree case)
#
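# Note that the fields of the SNIP are the same as those of the
# RelayDescriptor, except that bw is removed and range and auth are added.
# (DirAuth.generate_consensus builds each SNIP from a copy of the whole
# descriptor dict, so sig and vrfkey may still be present in snipdict,
# but SNIP.__str__ does not serialize them.)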
class SNIP:
    """The per-relay record used by the Walking Onions modes.

    Fields live in self.snipdict (the expected keys are listed in the
    comment above this class).  How the "auth" field is produced and
    checked depends on network.thenetwork.snipauthmode: in THRESHSIG mode
    it is a signature made by auth(); otherwise it is a Merkle path, a
    list of (hash, type) pairs attached by RelayPicker.
    """

    def __init__(self, snipdict):
        self.snipdict = snipdict

    def __str__(self, withauth = True):
        """Serialize the SNIP, one " key: value" line per field.

        Fields are emitted in a fixed order.  idkey, onionkey and
        pathselkey are hex-encoded.  "auth" is emitted only when withauth
        is true: hex-encoded in THRESHSIG mode, via str() otherwise.
        __str__(False) is the text that auth() signs, that verify()
        checks, and that serves as a Merkle tree leaf.
        """
        res = "SNIP [\n"
        for k in ["epoch", "idkey", "onionkey", "pathselkey", "addr",
                "flags", "range", "auth"]:
            if k not in self.snipdict:
                continue
            if k == "idkey" or k == "onionkey" or k == "pathselkey":
                res += " " + k + ": " + nacl.encoding.HexEncoder.encode(self.snipdict[k]).decode("ascii") + "\n"
            elif k == "auth":
                if withauth:
                    if network.thenetwork.snipauthmode == \
                            network.SNIPAuthMode.THRESHSIG:
                        res += " " + k + ": " + nacl.encoding.HexEncoder.encode(self.snipdict[k]).decode("ascii") + "\n"
                    else:
                        # Merkle path.  Terminate the line like every
                        # other field; without the "\n" the closing "]"
                        # was glued onto this line.
                        res += " " + k + ": " + str(self.snipdict[k]) + "\n"
            else:
                res += " " + k + ": " + str(self.snipdict[k]) + "\n"
        res += "]\n"
        return res

    def auth(self, signingkey, perfstats):
        """Attach an authority signature (THRESHSIG mode only).

        Signs __str__(False) with signingkey, stores the raw signature in
        snipdict["auth"], and counts one signature in perfstats.sigs.

        Raises:
            ValueError: if the network is not in THRESHSIG SNIP-auth mode
                (Merkle paths are attached elsewhere, not here).
        """
        if network.thenetwork.snipauthmode == network.SNIPAuthMode.THRESHSIG:
            serialized = self.__str__(False)
            signed = signingkey.sign(serialized.encode("ascii"))
            perfstats.sigs += 1
            self.snipdict["auth"] = signed.signature
        else:
            raise ValueError("Merkle auth not valid for SNIP.auth")

    @staticmethod
    def verify(snip, consensus, verifykey, perfstats):
        """Check snip's "auth" field.

        THRESHSIG mode: snip and consensus must share an epoch, and
        verifykey must accept the signature over snip.__str__(False); one
        verification is counted in perfstats.verifs.  Otherwise the
        stored Merkle path must lead from snip.__str__(False) to the
        consensus's merkleroot.  Nothing is returned; failures raise.

        NOTE(review): the Merkle check (and the THRESHSIG type/epoch
        checks) are assert statements, so they are skipped under
        ``python -O``.
        """
        if network.thenetwork.snipauthmode == network.SNIPAuthMode.THRESHSIG:
            assert(type(snip) is SNIP and type(consensus) is Consensus)
            assert(consensus.consdict["epoch"] == snip.snipdict["epoch"])
            serialized = snip.__str__(False)
            perfstats.verifs += 1
            verifykey.verify(serialized.encode("ascii"),
                    snip.snipdict["auth"])
        else:
            assert(merklelib.verify_leaf_inclusion(
                snip.__str__(False),
                [merklelib.AuditNode(p[0], p[1])
                    for p in snip.snipdict["auth"]],
                merklelib.Hasher(), consensus.consdict["merkleroot"]))
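# A consensus (Consensus) wraps a dict, consdict, containing:
# epoch: epoch id
# numrelays: total number of relays
# totbw: total bandwidth of relays (Vanilla Onion Routing); in the
#     Walking Onions modes this is fixed at 1<<32, the size of the SNIP
#     index space
# merkleroot: the root of the SNIP Merkle tree (Merkle tree auth only)
# relays: list of relay descriptors (Vanilla Onion Routing only)
# sigs: list of signatures from the dirauths, indexed by dirauth number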
class Consensus:
    """The per-epoch consensus document signed by the directory
    authorities.

    Its contents live in self.consdict: "epoch", "numrelays", "totbw",
    plus "relays" (Vanilla Onion Routing only), "merkleroot" (added by
    the dirauths in Merkle SNIP-auth mode) and "sigs" (added by sign()).
    """

    def __init__(self, epoch, relays):
        """Build the consensus for epoch.

        Only relay descriptors whose own "epoch" field equals epoch are
        kept; the rest are silently dropped.
        """
        current = [desc for desc in relays
                   if desc.descdict['epoch'] == epoch]
        self.consdict = {'epoch': epoch, 'numrelays': len(current)}
        if network.thenetwork.womode == network.WOMode.VANILLA:
            # Vanilla clients pick relays directly from the consensus, so
            # it carries the descriptors and their total bandwidth.
            self.consdict['totbw'] = sum(desc.descdict['bw']
                                         for desc in current)
            self.consdict['relays'] = current
        else:
            # Walking Onions: selection indexes a fixed 32-bit space.
            self.consdict['totbw'] = 1<<32

    def __str__(self, withsigs = True):
        """Serialize the consensus.

        withsigs=False omits the "sig" lines; that form is what sign()
        and verify() operate on.
        """
        parts = ["Consensus [\n"]
        for key in ("epoch", "numrelays", "totbw"):
            if key in self.consdict:
                parts.append(" " + key + ": " + str(self.consdict[key]) + "\n")
        if network.thenetwork.womode == network.WOMode.VANILLA:
            parts.extend(str(relay) for relay in self.consdict['relays'])
        if network.thenetwork.snipauthmode == network.SNIPAuthMode.MERKLE \
                and "merkleroot" in self.consdict:
            parts.append(" merkleroot: " + str(self.consdict["merkleroot"]) + "\n")
        if withsigs and 'sigs' in self.consdict:
            for sig in self.consdict['sigs']:
                parts.append(" sig: " + nacl.encoding.HexEncoder.encode(sig).decode("ascii") + "\n")
        parts.append("]\n")
        return "".join(parts)

    def sign(self, signingkey, index, perfstats):
        """Use the given signing key to sign the consensus, storing the
        result in the sigs list at the given index.

        The list is padded with None up to index if it is too short, and
        one signature is counted in perfstats.sigs.
        """
        signed = signingkey.sign(self.__str__(False).encode("ascii"))
        perfstats.sigs += 1
        sigs = self.consdict.setdefault('sigs', [])
        missing = index + 1 - len(sigs)
        if missing > 0:
            sigs.extend([None] * missing)
        sigs[index] = signed.signature

    @staticmethod
    def verify(consensus, verifkeylist, perfstats):
        """Use the given list of verification keys to check the
        signatures on the consensus. Return the RelayPicker if
        successful, or raise an exception otherwise.

        Key i is checked against sigs[i]; each check is counted in
        perfstats.verifs.
        """
        assert type(consensus) is Consensus
        text = consensus.__str__(False)
        for position, verifier in enumerate(verifkeylist):
            perfstats.verifs += 1
            verifier.verify(text.encode("ascii"),
                            consensus.consdict['sigs'][position])
        # Every signature checked out (verify raises otherwise).
        return RelayPicker.get(consensus)
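# An ENDIVE wraps a dict, enddict, containing:
# epoch: epoch id
# snips: list of SNIPS (in THRESHSIG mode, these include the auth
#     signatures; in MERKLE mode, these do _not_ include auth when the
#     ENDIVE is created and signed, although RelayPicker later attaches
#     Merkle paths to the SNIPs it hands out)
# sigs: list of signatures from the dirauths, indexed by dirauth number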
class ENDIVE:
    """The signed list of all SNIPs for one epoch (Walking Onions modes).

    Its contents live in self.enddict: "epoch", "snips" and, once
    signed, "sigs" (a list indexed by the signer's position).
    """

    def __init__(self, epoch, snips):
        # Keep only the SNIPs whose own epoch matches the requested one.
        matching = [snip for snip in snips
                    if snip.snipdict['epoch'] == epoch]
        self.enddict = {'epoch': epoch, 'snips': matching}

    def __str__(self, withsigs = True):
        """Serialize the ENDIVE: the epoch line, each SNIP's str(), then
        (only if withsigs) one hex "sig" line per signature."""
        pieces = ["ENDIVE [\n"]
        if "epoch" in self.enddict:
            pieces.append(" epoch: " + str(self.enddict["epoch"]) + "\n")
        pieces.extend(str(snip) for snip in self.enddict['snips'])
        if withsigs and 'sigs' in self.enddict:
            for sig in self.enddict['sigs']:
                pieces.append(" sig: " + nacl.encoding.HexEncoder.encode(sig).decode("ascii") + "\n")
        pieces.append("]\n")
        return "".join(pieces)

    def sign(self, signingkey, index, perfstats):
        """Use the given signing key to sign the ENDIVE, storing the
        result in the sigs list at the given index.

        The list is padded with None up to index if it is too short, and
        one signature is counted in perfstats.sigs.
        """
        signed = signingkey.sign(self.__str__(False).encode("ascii"))
        perfstats.sigs += 1
        sigs = self.enddict.setdefault('sigs', [])
        shortfall = index + 1 - len(sigs)
        if shortfall > 0:
            sigs.extend([None] * shortfall)
        sigs[index] = signed.signature

    @staticmethod
    def verify(endive, consensus, verifkeylist, perfstats):
        """Use the given list of verification keys to check the
        signatures on the ENDIVE and consensus. Return the RelayPicker
        if successful, or raise an exception otherwise.

        The consensus is checked first, then the ENDIVE; key i is matched
        against sigs[i] of each, and every check is counted in
        perfstats.verifs.
        """
        assert type(endive) is ENDIVE and type(consensus) is Consensus
        for document, fields in ((consensus, consensus.consdict),
                                 (endive, endive.enddict)):
            text = document.__str__(False)
            for position, verifier in enumerate(verifkeylist):
                perfstats.verifs += 1
                verifier.verify(text.encode("ascii"), fields['sigs'][position])
        # Every signature checked out (verify raises otherwise).
        return RelayPicker.get(consensus, endive)
class RelayPicker:
    """An instance of this class (which may be a singleton in the
    simulation) is returned by the Consensus.verify() and
    ENDIVE.verify() methods. It does any necessary precomputation
    and/or caching, and exposes a method to select a random bw-weighted
    relay, either explicitly specifying a uniform random value, or
    letting the choice be done internally."""

    # The singleton instance, managed by get()
    relaypicker = None

    def __init__(self, consensus, endive = None):
        """Precompute the selection table for consensus (and endive).

        Vanilla mode: cdf[i] is the total bw of relays 0..i-1.
        Walking Onions modes: when an ENDIVE is supplied, cdf[i] is the
        low end of SNIP i's index range; clients pass endive=None and get
        cdf = None.  In Merkle SNIP-auth mode the Merkle tree over the
        ENDIVE's SNIPs is rebuilt, and its root is asserted equal to the
        consensus's merkleroot.
        """
        self.epoch = consensus.consdict["epoch"]
        self.totbw = consensus.consdict["totbw"]
        self.consensus = consensus
        self.endive = endive
        assert(endive is None or endive.enddict["epoch"] == self.epoch)
        if network.thenetwork.womode == network.WOMode.VANILLA:
            # Create the array of cumulative bandwidth values from a
            # consensus. The array (cdf) will have the same length as
            # the number of relays in the consensus. cdf[0] = 0, and
            # cdf[i] = cdf[i-1] + relay[i-1].bw.
            self.cdf = [0]
            for r in consensus.consdict['relays']:
                self.cdf.append(self.cdf[-1]+r.descdict['bw'])
            # Remove the last item, which should be the sum of all the bws
            self.cdf.pop()
            logging.debug('cdf=%s', self.cdf)
        else:
            # Note that clients will call this with endive = None
            if self.endive is not None:
                snips = self.endive.enddict['snips']
                self.cdf = [s.snipdict['range'][0] for s in snips]
                if network.thenetwork.snipauthmode == \
                        network.SNIPAuthMode.MERKLE:
                    # Construct the Merkle tree of *this* ENDIVE's SNIPs
                    # and check the root matches the one in the
                    # consensus.  (The class-level DirAuth.endive is
                    # replaced whenever a new consensus is generated, so
                    # it need not belong to this picker's epoch.)
                    self.merkletree = merklelib.MerkleTree(
                        [snip.__str__(False) for snip in snips],
                        merklelib.Hasher())
                    assert(self.consensus.consdict["merkleroot"] == \
                        self.merkletree.merkle_root)
            else:
                self.cdf = None
            logging.debug('cdf=%s', self.cdf)

    @staticmethod
    def get(consensus, endive = None):
        """Return the cached picker for consensus's epoch, building (and
        caching) a new one when needed.

        A cached picker built without an ENDIVE is not reused when the
        caller supplies a real ENDIVE.
        """
        if RelayPicker.relaypicker is not None and \
                (RelayPicker.relaypicker.endive is not None or \
                    endive is None) and \
                RelayPicker.relaypicker.epoch == consensus.consdict["epoch"]:
            return RelayPicker.relaypicker
        # Create it otherwise, storing the result as the singleton
        RelayPicker.relaypicker = RelayPicker(consensus, endive)
        return RelayPicker.relaypicker

    def pick_relay_by_uniform_index(self, idx):
        """Pass in a uniform random index in [0, totbw-1] to get a
        relay's descriptor (Vanilla) or SNIP (Walking Onions), selected
        with probability proportional to its bw.

        In Merkle SNIP-auth mode the returned SNIP gets its Merkle path
        stored in snipdict["auth"] on first use.
        """
        if network.thenetwork.womode == network.WOMode.VANILLA:
            relays = self.consensus.consdict['relays']
        else:
            relays = self.endive.enddict['snips']
        # Find the rightmost entry less than or equal to idx; taking the
        # rightmost skips entries with empty ranges / zero bw.
        i = bisect.bisect_right(self.cdf, idx)
        r = relays[i-1]
        if network.thenetwork.snipauthmode == \
                network.SNIPAuthMode.MERKLE:
            # If we haven't yet computed the Merkle path for this SNIP,
            # do it now, and store it in the SNIP so that the client
            # will get it.
            if "auth" not in r.snipdict:
                r.snipdict["auth"] = [ (p.hash, p.type) for p in \
                    self.merkletree.get_proof(r.__str__(False))._nodes]
        return r

    def pick_weighted_relay(self):
        """Select a random relay with probability proportional to its bw
        weight."""
        idx = self.pick_weighted_relay_index()
        return self.pick_relay_by_uniform_index(idx)

    def pick_weighted_relay_index(self):
        """Select a random relay index (for use in Walking Onions)
        uniformly from [0, totbw-1], which results in picking a relay
        with probability proportional to its bw weight.

        Raises:
            ValueError: if totbw < 1 (no relays to choose from).
        """
        totbw = self.totbw
        if totbw < 1:
            raise ValueError("No relays to choose from")
        return random.randint(0, totbw-1)
class DirAuth(network.Server):
    """The class representing directory authorities."""

    # We simulate the act of computing the consensus by keeping a
    # class-static dict that's accessible to all of the dirauths
    # This dict is indexed by epoch, and the value is itself a dict
    # indexed by the stringified descriptor, with value a pair of (the
    # number of dirauths that saw that descriptor, the descriptor
    # itself).
    uploadeddescs = dict()
    # The most recently generated Consensus and (Walking Onions only)
    # ENDIVE; shared by all dirauths and served in reply to Get* messages.
    consensus = None
    endive = None

    def __init__(self, me, tot):
        """Create a new directory authority. me is the index of which
        dirauth this one is (starting from 0), and tot is the total
        number of dirauths."""
        self.me = me
        self.tot = tot
        self.name = "Dirauth %d of %d" % (me+1, tot)
        self.perfstats = network.PerfStats(network.EntType.DIRAUTH)
        self.perfstats.is_bootstrapping = True
        # Create the dirauth signature keypair
        self.sigkey = nacl.signing.SigningKey.generate()
        self.perfstats.keygens += 1
        self.netaddr = network.thenetwork.bind(self)
        self.perfstats.name = "DirAuth at %s" % self.netaddr
        network.thenetwork.setdirauthkey(self.me, self.sigkey.verify_key)
        network.thenetwork.wantepochticks(self, True, True, True)

    def connected(self, client):
        """Callback invoked when a client connects to us. This callback
        creates the DirAuthConnection that will be passed to the
        client."""
        # We don't actually need to keep per-connection state at
        # dirauths, even in long-lived connections, so this is
        # particularly simple.
        return DirAuthConnection(self)

    def generate_consensus(self, epoch):
        """Generate the consensus (and ENDIVE, if using Walking Onions)
        for the given epoch, which should be the one after the one
        that's currently about to end.

        A descriptor is included when it was uploaded at least a strict
        majority of self.tot times.  In Walking Onions modes each SNIP
        gets the slice of [0, 2**32) proportional to its relay's bw.
        """
        threshold = self.tot // 2 + 1
        # No uploads for this epoch simply yields an empty consensus.
        uploaded = DirAuth.uploadeddescs.get(epoch, {})
        consensusdescs = [desc for numseen, desc in uploaded.values()
                          if numseen >= threshold]
        DirAuth.consensus = Consensus(epoch, consensusdescs)
        if network.thenetwork.womode != network.WOMode.VANILLA:
            totbw = sum(d.descdict["bw"] for d in consensusdescs)
            hi = 0
            cumbw = 0
            snips = []
            for d in consensusdescs:
                cumbw += d.descdict["bw"]
                lo = hi
                # Exact integer arithmetic: float division is inexact
                # once cumbw<<32 exceeds 2**53 and can round a boundary
                # up by one.  The final hi is exactly 1<<32.
                hi = (cumbw << 32) // totbw
                snipdict = dict(d.descdict)
                del snipdict["bw"]
                snipdict["range"] = (lo, hi)
                snips.append(SNIP(snipdict))
            DirAuth.endive = ENDIVE(epoch, snips)

    def epoch_ending(self, epoch):
        """Epoch-tick callback: produce and sign the directory for epoch+1.

        Only dirauth 0 actually needs to generate the consensus because
        of the shared class-static state, but everyone has to sign it.
        Note that this code relies on dirauth 0's epoch_ending callback
        being called before any of the other dirauths'.
        """
        if self.me == 0:
            self.generate_consensus(epoch+1)
            # The uploads for epoch+1 have been consumed.  Only the
            # generator touches this entry: having every dirauth
            # (re)create it left an empty dict behind for each epoch.
            DirAuth.uploadeddescs.pop(epoch+1, None)
            if network.thenetwork.snipauthmode == \
                    network.SNIPAuthMode.THRESHSIG:
                for s in DirAuth.endive.enddict['snips']:
                    s.auth(self.sigkey, self.perfstats)
            elif network.thenetwork.snipauthmode == \
                    network.SNIPAuthMode.MERKLE:
                # Construct the Merkle tree of the SNIPs in the ENDIVE
                # and put the root in the consensus
                tree = merklelib.MerkleTree(
                    [snip.__str__(False) \
                        for snip in DirAuth.endive.enddict['snips']],
                    merklelib.Hasher())
                DirAuth.consensus.consdict["merkleroot"] = tree.merkle_root
        else:
            if network.thenetwork.snipauthmode == \
                    network.SNIPAuthMode.THRESHSIG:
                # We're just simulating threshold sigs by having only
                # the first dirauth sign, but in reality each dirauth
                # would contribute to the signature (at the same cost as
                # each one signing), so we charge their perfstats one
                # signature per SNIP as well.
                self.perfstats.sigs += len(DirAuth.endive.enddict['snips'])
        DirAuth.consensus.sign(self.sigkey, self.me, self.perfstats)
        if network.thenetwork.womode != network.WOMode.VANILLA:
            DirAuth.endive.sign(self.sigkey, self.me, self.perfstats)

    def _send_reply(self, client, replymsg):
        """Send replymsg to client, charging its size to bytes_sent."""
        self.perfstats.bytes_sent += replymsg.size()
        client.reply(replymsg)

    def received(self, client, message):
        """Handle a client-originating message.

        Upload/Del messages adjust the shared upload counts for the next
        epoch (messages for any other epoch are ignored); Get* messages
        are answered with the current consensus or ENDIVE.

        Raises:
            TypeError: for any other message type.
        """
        self.perfstats.bytes_received += message.size()
        if isinstance(message, simmsg.DirAuthUploadDescMsg):
            # Check the uploaded descriptor for sanity
            epoch = message.desc.descdict['epoch']
            if epoch != network.thenetwork.getepoch() + 1:
                return
            # Store it in the class-static dict, keeping the first copy
            # of the descriptor and bumping its count.
            uploaded = DirAuth.uploadeddescs.setdefault(epoch, dict())
            descstr = str(message.desc)
            numseen, desc = uploaded.get(descstr, (0, message.desc))
            uploaded[descstr] = (numseen + 1, desc)
        elif isinstance(message, simmsg.DirAuthDelDescMsg):
            # Check the uploaded descriptor for sanity
            epoch = message.desc.descdict['epoch']
            if epoch != network.thenetwork.getepoch() + 1:
                return
            # Remove it from the class-static dict
            uploaded = DirAuth.uploadeddescs.get(epoch)
            if uploaded is None:
                return
            descstr = str(message.desc)
            if descstr not in uploaded:
                return
            numseen, desc = uploaded[descstr]
            if numseen == 1:
                del uploaded[descstr]
            else:
                uploaded[descstr] = (numseen - 1, desc)
        elif isinstance(message, simmsg.DirAuthGetConsensusMsg):
            self._send_reply(client,
                    simmsg.DirAuthConsensusMsg(DirAuth.consensus))
        elif isinstance(message, simmsg.DirAuthGetConsensusDiffMsg):
            self._send_reply(client,
                    simmsg.DirAuthConsensusDiffMsg(DirAuth.consensus))
        elif isinstance(message, simmsg.DirAuthGetENDIVEMsg):
            self._send_reply(client,
                    simmsg.DirAuthENDIVEMsg(DirAuth.endive))
        elif isinstance(message, simmsg.DirAuthGetENDIVEDiffMsg):
            self._send_reply(client,
                    simmsg.DirAuthENDIVEDiffMsg(DirAuth.endive))
        else:
            raise TypeError('Not a client-originating DirAuthNetMsg', message)

    def closed(self):
        pass