diff --git a/kmip/core/exceptions.py b/kmip/core/exceptions.py index c28b5bb..0f5695e 100644 --- a/kmip/core/exceptions.py +++ b/kmip/core/exceptions.py @@ -211,6 +211,14 @@ class ConfigurationError(Exception): pass +class ConnectionClosed(Exception): + """ + An exception generated when attempting to use a connection that has been + closed. + """ + pass + + class NetworkingError(Exception): """ An error generated when a problem occurs with client or server networking diff --git a/kmip/services/server/session.py b/kmip/services/server/session.py index 314d290..3ff8df9 100644 --- a/kmip/services/server/session.py +++ b/kmip/services/server/session.py @@ -70,15 +70,18 @@ class KmipSession(threading.Thread): """ self._logger.info("Starting session: {0}".format(self.name)) - try: - self._handle_message_loop() - except Exception as e: - self._logger.info("Failure handling message loop") - self._logger.exception(e) - finally: - self._connection.shutdown(socket.SHUT_RDWR) - self._connection.close() - self._logger.info("Stopping session: {0}".format(self.name)) + while True: + try: + self._handle_message_loop() + except exceptions.ConnectionClosed as e: + break + except Exception as e: + self._logger.info("Failure handling message loop") + self._logger.exception(e) + + self._connection.shutdown(socket.SHUT_RDWR) + self._connection.close() + self._logger.info("Stopping session: {0}".format(self.name)) def _handle_message_loop(self): request_data = self._receive_request() @@ -164,6 +167,8 @@ class KmipSession(threading.Thread): if partial_message is None: break + elif len(partial_message) == 0: + raise exceptions.ConnectionClosed() else: bytes_received += len(partial_message) message += partial_message diff --git a/kmip/tests/unit/services/server/test_session.py b/kmip/tests/unit/services/server/test_session.py index 1580237..e554d4a 100644 --- a/kmip/tests/unit/services/server/test_session.py +++ b/kmip/tests/unit/services/server/test_session.py @@ -19,6 +19,7 @@ import testtools import time from kmip.core import enums +from kmip.core import exceptions from kmip.core import utils from kmip.core.messages import contents @@ -58,13 +59,18 @@ class TestKmipSession(testtools.TestCase): """ kmip_session = session.KmipSession(None, None, 'name') kmip_session._logger = mock.MagicMock() - kmip_session._handle_message_loop = mock.MagicMock() + kmip_session._handle_message_loop = mock.MagicMock( + side_effect=[ + None, + exceptions.ConnectionClosed() + ] + ) kmip_session._connection = mock.MagicMock() kmip_session.run() kmip_session._logger.info.assert_any_call("Starting session: name") - kmip_session._handle_message_loop.assert_called_once_with() + self.assertTrue(kmip_session._handle_message_loop.called) kmip_session._connection.shutdown.assert_called_once_with( socket.SHUT_RDWR ) @@ -82,13 +88,16 @@ class TestKmipSession(testtools.TestCase): test_exception = Exception("test") kmip_session._handle_message_loop = mock.MagicMock( - side_effect=test_exception + side_effect=[ + test_exception, + exceptions.ConnectionClosed() + ] ) kmip_session.run() kmip_session._logger.info.assert_any_call("Starting session: name") - kmip_session._handle_message_loop.assert_called_once_with() + self.assertTrue(kmip_session._handle_message_loop.called) kmip_session._logger.info.assert_any_call( "Failure handling message loop" ) @@ -267,6 +276,17 @@ class TestKmipSession(testtools.TestCase): kmip_session._connection.recv.assert_called_with(8) self.assertEqual(content + content, observed) + kmip_session._connection.recv = mock.MagicMock( + side_effect=[''] + ) + + args = (8, ) + self.assertRaises( + exceptions.ConnectionClosed, + kmip_session._receive_bytes, + *args + ) + def test_receive_bytes_with_bad_length(self): """ Test that the session generates an error on an incorrectly sized