Skip to content
Snippets Groups Projects
Commit 1432f7cc authored by Erik Johnston's avatar Erik Johnston
Browse files

Move test stuff to tests

parent 2f18a264
Branches
Tags
No related merge requests found
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer, threads, reactor from twisted.internet import threads, reactor
from synapse.util.logcontext import make_deferred_yieldable from synapse.util.logcontext import make_deferred_yieldable
...@@ -57,10 +57,6 @@ class BackgroundFileConsumer(object): ...@@ -57,10 +57,6 @@ class BackgroundFileConsumer(object):
# If the _writer thread throws an exception it gets stored here. # If the _writer thread throws an exception it gets stored here.
self._write_exception = None self._write_exception = None
# A deferred that gets resolved when the bytes_queue gets empty.
# Mainly used for tests.
self._notify_empty_deferred = None
def registerProducer(self, producer, streaming): def registerProducer(self, producer, streaming):
"""Part of IConsumer interface """Part of IConsumer interface
...@@ -113,9 +109,6 @@ class BackgroundFileConsumer(object): ...@@ -113,9 +109,6 @@ class BackgroundFileConsumer(object):
if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE: if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE:
reactor.callFromThread(self._resume_paused_producer) reactor.callFromThread(self._resume_paused_producer)
if self._notify_empty_deferred and self._bytes_queue.empty():
reactor.callFromThread(self._notify_empty)
bytes = self._bytes_queue.get() bytes = self._bytes_queue.get()
# If we get a None (or empty list) then that's a signal used # If we get a None (or empty list) then that's a signal used
...@@ -144,20 +137,3 @@ class BackgroundFileConsumer(object): ...@@ -144,20 +137,3 @@ class BackgroundFileConsumer(object):
if self._paused_producer and self._producer: if self._paused_producer and self._producer:
self._paused_producer = False self._paused_producer = False
self._producer.resumeProducing() self._producer.resumeProducing()
def _notify_empty(self):
"""Called when the _writer thread thinks the queue may be empty and
we should notify anything waiting on `wait_for_writes`
"""
if self._notify_empty_deferred and self._bytes_queue.empty():
d = self._notify_empty_deferred
self._notify_empty_deferred = None
d.callback(None)
def wait_for_writes(self):
"""Wait for the write queue to be empty and for writes to have
finished. This is mainly useful for tests.
"""
if not self._notify_empty_deferred:
self._notify_empty_deferred = defer.Deferred()
return self._notify_empty_deferred
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer, reactor
from mock import NonCallableMock from mock import NonCallableMock
from synapse.util.file_consumer import BackgroundFileConsumer from synapse.util.file_consumer import BackgroundFileConsumer
...@@ -53,7 +53,7 @@ class FileConsumerTests(unittest.TestCase): ...@@ -53,7 +53,7 @@ class FileConsumerTests(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_push_consumer(self): def test_push_consumer(self):
string_file = StringIO() string_file = BlockingStringWrite()
consumer = BackgroundFileConsumer(string_file) consumer = BackgroundFileConsumer(string_file)
try: try:
...@@ -62,14 +62,14 @@ class FileConsumerTests(unittest.TestCase): ...@@ -62,14 +62,14 @@ class FileConsumerTests(unittest.TestCase):
consumer.registerProducer(producer, True) consumer.registerProducer(producer, True)
consumer.write("Foo") consumer.write("Foo")
yield consumer.wait_for_writes() yield string_file.wait_for_n_writes(1)
self.assertEqual(string_file.getvalue(), "Foo") self.assertEqual(string_file.buffer, "Foo")
consumer.write("Bar") consumer.write("Bar")
yield consumer.wait_for_writes() yield string_file.wait_for_n_writes(2)
self.assertEqual(string_file.getvalue(), "FooBar") self.assertEqual(string_file.buffer, "FooBar")
finally: finally:
consumer.unregisterProducer() consumer.unregisterProducer()
...@@ -85,15 +85,22 @@ class FileConsumerTests(unittest.TestCase): ...@@ -85,15 +85,22 @@ class FileConsumerTests(unittest.TestCase):
try: try:
producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"]) producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
resume_deferred = defer.Deferred()
producer.resumeProducing.side_effect = lambda: resume_deferred.callback(None)
consumer.registerProducer(producer, True) consumer.registerProducer(producer, True)
number_writes = 0
with string_file.write_lock: with string_file.write_lock:
for _ in range(consumer._PAUSE_ON_QUEUE_SIZE): for _ in range(consumer._PAUSE_ON_QUEUE_SIZE):
consumer.write("Foo") consumer.write("Foo")
number_writes += 1
producer.pauseProducing.assert_called_once() producer.pauseProducing.assert_called_once()
yield consumer.wait_for_writes() yield string_file.wait_for_n_writes(number_writes)
yield resume_deferred
producer.resumeProducing.assert_called_once() producer.resumeProducing.assert_called_once()
finally: finally:
consumer.unregisterProducer() consumer.unregisterProducer()
...@@ -131,8 +138,39 @@ class BlockingStringWrite(object): ...@@ -131,8 +138,39 @@ class BlockingStringWrite(object):
self.closed = False self.closed = False
self.write_lock = threading.Lock() self.write_lock = threading.Lock()
self._notify_write_deferred = None
self._number_of_writes = 0
def write(self, bytes): def write(self, bytes):
self.buffer += bytes with self.write_lock:
self.buffer += bytes
self._number_of_writes += 1
reactor.callFromThread(self._notify_write)
def close(self): def close(self):
self.closed = True self.closed = True
def _notify_write(self):
"Called by write to indicate a write happened"
with self.write_lock:
if not self._notify_write_deferred:
return
d = self._notify_write_deferred
self._notify_write_deferred = None
d.callback(None)
@defer.inlineCallbacks
def wait_for_n_writes(self, n):
"Wait for n writes to have happened"
while True:
with self.write_lock:
if n <= self._number_of_writes:
return
if not self._notify_write_deferred:
self._notify_write_deferred = defer.Deferred()
d = self._notify_write_deferred
yield d
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment