Skip to content
Snippets Groups Projects
Commit a2866e2e authored by Mark Haines's avatar Mark Haines
Browse files

Rename direction to step, apply checks consistently

parent e36bfbab
No related branches found
No related tags found
No related merge requests found
...@@ -97,7 +97,7 @@ class DataStore(RoomMemberStore, RoomStore, ...@@ -97,7 +97,7 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "events", "stream_ordering" db_conn, "events", "stream_ordering"
) )
self._backfill_id_gen = StreamIdGenerator( self._backfill_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering", direction=-1 db_conn, "events", "stream_ordering", step=-1
) )
self._receipts_id_gen = StreamIdGenerator( self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id" db_conn, "receipts_linearized", "stream_id"
......
...@@ -29,16 +29,16 @@ class IdGenerator(object): ...@@ -29,16 +29,16 @@ class IdGenerator(object):
return self._next_id return self._next_id
def _load_current_id(db_conn, table, column, direction=1): def _load_current_id(db_conn, table, column, step=1):
cur = db_conn.cursor() cur = db_conn.cursor()
if direction == 1: if step == 1:
cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
else: else:
cur.execute("SELECT MIN(%s) FROM %s" % (column, table,)) cur.execute("SELECT MIN(%s) FROM %s" % (column, table,))
val, = cur.fetchone() val, = cur.fetchone()
cur.close() cur.close()
current_id = int(val) if val else direction current_id = int(val) if val else step
return (max if direction == 1 else min)(current_id, direction) return (max if step > 0 else min)(current_id, step)
class StreamIdGenerator(object): class StreamIdGenerator(object):
...@@ -58,21 +58,21 @@ class StreamIdGenerator(object): ...@@ -58,21 +58,21 @@ class StreamIdGenerator(object):
:param list extra_tables: List of pairs of database tables and columns to :param list extra_tables: List of pairs of database tables and columns to
use to source the initial value of the generator from. The value with use to source the initial value of the generator from. The value with
the largest magnitude is used. the largest magnitude is used.
:param int direction: which direction the stream ids grow in. +1 to grow :param int step: which direction the stream ids grow in. +1 to grow
upwards, -1 to grow downwards. upwards, -1 to grow downwards.
Usage: Usage:
with stream_id_gen.get_next() as stream_id: with stream_id_gen.get_next() as stream_id:
# ... persist event ... # ... persist event ...
""" """
def __init__(self, db_conn, table, column, extra_tables=[], direction=1): def __init__(self, db_conn, table, column, extra_tables=[], step=1):
self._lock = threading.Lock() self._lock = threading.Lock()
self._direction = direction self._step = step
self._current = _load_current_id(db_conn, table, column, direction) self._current = _load_current_id(db_conn, table, column, step)
for table, column in extra_tables: for table, column in extra_tables:
self._current = (max if direction > 0 else min)( self._current = (max if step > 0 else min)(
self._current, self._current,
_load_current_id(db_conn, table, column, direction) _load_current_id(db_conn, table, column, step)
) )
self._unfinished_ids = deque() self._unfinished_ids = deque()
...@@ -83,7 +83,7 @@ class StreamIdGenerator(object): ...@@ -83,7 +83,7 @@ class StreamIdGenerator(object):
# ... persist event ... # ... persist event ...
""" """
with self._lock: with self._lock:
self._current += self._direction self._current += self._step
next_id = self._current next_id = self._current
self._unfinished_ids.append(next_id) self._unfinished_ids.append(next_id)
...@@ -106,9 +106,9 @@ class StreamIdGenerator(object): ...@@ -106,9 +106,9 @@ class StreamIdGenerator(object):
""" """
with self._lock: with self._lock:
next_ids = range( next_ids = range(
self._current + self._direction, self._current + self._step,
self._current + self._direction * (n + 1), self._current + self._step * (n + 1),
self._direction self._step
) )
self._current += n self._current += n
...@@ -132,7 +132,7 @@ class StreamIdGenerator(object): ...@@ -132,7 +132,7 @@ class StreamIdGenerator(object):
""" """
with self._lock: with self._lock:
if self._unfinished_ids: if self._unfinished_ids:
return self._unfinished_ids[0] - self._direction return self._unfinished_ids[0] - self._step
return self._current return self._current
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment