Merge pull request #159 from OpenKMIP/bug/fix-early-close

Fixing bug terminating connection prematurely
This commit is contained in:
Peter Hamilton 2016-04-05 14:09:41 -04:00
commit b152941b68
3 changed files with 46 additions and 13 deletions

View File

@ -211,6 +211,14 @@ class ConfigurationError(Exception):
pass pass
class ConnectionClosed(Exception):
"""
An exception generated when attempting to use a connection that has been
closed.
"""
pass
class NetworkingError(Exception): class NetworkingError(Exception):
""" """
An error generated when a problem occurs with client or server networking An error generated when a problem occurs with client or server networking

View File

@ -70,15 +70,18 @@ class KmipSession(threading.Thread):
""" """
self._logger.info("Starting session: {0}".format(self.name)) self._logger.info("Starting session: {0}".format(self.name))
try: while True:
self._handle_message_loop() try:
except Exception as e: self._handle_message_loop()
self._logger.info("Failure handling message loop") except exceptions.ConnectionClosed as e:
self._logger.exception(e) break
finally: except Exception as e:
self._connection.shutdown(socket.SHUT_RDWR) self._logger.info("Failure handling message loop")
self._connection.close() self._logger.exception(e)
self._logger.info("Stopping session: {0}".format(self.name))
self._connection.shutdown(socket.SHUT_RDWR)
self._connection.close()
self._logger.info("Stopping session: {0}".format(self.name))
def _handle_message_loop(self): def _handle_message_loop(self):
request_data = self._receive_request() request_data = self._receive_request()
@ -164,6 +167,8 @@ class KmipSession(threading.Thread):
if partial_message is None: if partial_message is None:
break break
elif len(partial_message) == 0:
raise exceptions.ConnectionClosed()
else: else:
bytes_received += len(partial_message) bytes_received += len(partial_message)
message += partial_message message += partial_message

View File

@ -19,6 +19,7 @@ import testtools
import time import time
from kmip.core import enums from kmip.core import enums
from kmip.core import exceptions
from kmip.core import utils from kmip.core import utils
from kmip.core.messages import contents from kmip.core.messages import contents
@ -58,13 +59,18 @@ class TestKmipSession(testtools.TestCase):
""" """
kmip_session = session.KmipSession(None, None, 'name') kmip_session = session.KmipSession(None, None, 'name')
kmip_session._logger = mock.MagicMock() kmip_session._logger = mock.MagicMock()
kmip_session._handle_message_loop = mock.MagicMock() kmip_session._handle_message_loop = mock.MagicMock(
side_effect=[
None,
exceptions.ConnectionClosed()
]
)
kmip_session._connection = mock.MagicMock() kmip_session._connection = mock.MagicMock()
kmip_session.run() kmip_session.run()
kmip_session._logger.info.assert_any_call("Starting session: name") kmip_session._logger.info.assert_any_call("Starting session: name")
kmip_session._handle_message_loop.assert_called_once_with() self.assertTrue(kmip_session._handle_message_loop.called)
kmip_session._connection.shutdown.assert_called_once_with( kmip_session._connection.shutdown.assert_called_once_with(
socket.SHUT_RDWR socket.SHUT_RDWR
) )
@ -82,13 +88,16 @@ class TestKmipSession(testtools.TestCase):
test_exception = Exception("test") test_exception = Exception("test")
kmip_session._handle_message_loop = mock.MagicMock( kmip_session._handle_message_loop = mock.MagicMock(
side_effect=test_exception side_effect=[
test_exception,
exceptions.ConnectionClosed()
]
) )
kmip_session.run() kmip_session.run()
kmip_session._logger.info.assert_any_call("Starting session: name") kmip_session._logger.info.assert_any_call("Starting session: name")
kmip_session._handle_message_loop.assert_called_once_with() self.assertTrue(kmip_session._handle_message_loop.called)
kmip_session._logger.info.assert_any_call( kmip_session._logger.info.assert_any_call(
"Failure handling message loop" "Failure handling message loop"
) )
@ -267,6 +276,17 @@ class TestKmipSession(testtools.TestCase):
kmip_session._connection.recv.assert_called_with(8) kmip_session._connection.recv.assert_called_with(8)
self.assertEqual(content + content, observed) self.assertEqual(content + content, observed)
kmip_session._connection.recv = mock.MagicMock(
side_effect=['']
)
args = (8, )
self.assertRaises(
exceptions.ConnectionClosed,
kmip_session._receive_bytes,
*args
)
def test_receive_bytes_with_bad_length(self): def test_receive_bytes_with_bad_length(self):
""" """
Test that the session generates an error on an incorrectly sized Test that the session generates an error on an incorrectly sized