diff --git a/tests/test_classes.py b/tests/test_classes.py index 2c805dc..9d454de 100644 --- a/tests/test_classes.py +++ b/tests/test_classes.py @@ -95,22 +95,29 @@ async def test_eventworker() -> None: transport = mock.Mock() transport.write = mock.Mock() transport.is_closing = mock.Mock() - protocol._drain_helper = make_mocked_coro() - - loop = asyncio.get_event_loop() - writer = asyncio.StreamWriter(transport, protocol, None, loop) - - worker: pytak.Worker = pytak.TXWorker(event_queue, {}, writer) - - await worker.run_once() - remaining_event = await event_queue.get() - assert b"taco2" == remaining_event - - popped = transport.write.mock_calls.pop() - - # Python 3.7: popped[1][0] - # Python 3.8+: popped.args[0] - assert b"taco1" == popped[1][0] + + protocol = mock.Mock() + protocol._drain_helper = mock.AsyncMock() + + mock_reader = mock.Mock(spec=asyncio.StreamReader) + mock_writer = mock.Mock(spec=asyncio.StreamWriter) + mock_writer.transport = transport + + mock_writer.write = transport.write + + with mock.patch('asyncio.open_connection', new=mock.AsyncMock(return_value=(mock_reader, mock_writer))): + _, writer = await asyncio.open_connection() + worker: pytak.Worker = pytak.TXWorker(event_queue, {}, writer) + + await worker.run_once() + + remaining_event = await event_queue.get() + assert b"taco2" == remaining_event + + popped = transport.write.mock_calls.pop() + # Python 3.7: popped[1][0] + # Python 3.8+: popped.args[0] + assert b"taco1" == popped[1][0] def test_simple_cot_event_to_xml() -> None: