From f6b420d2dbacd5795a80c0e520ce06790c9f6163 Mon Sep 17 00:00:00 2001 From: Peter Hamilton Date: Fri, 23 Jan 2015 15:26:28 -0500 Subject: [PATCH] Adding support for the DiscoverVersions operation This change adds support for the DiscoverVersions operation, including updates to the KMIP client, the client and KMIP core test suites, and a DiscoverVersions unit demo. --- kmip/core/factories/payloads/request.py | 4 + kmip/core/factories/payloads/response.py | 4 + kmip/core/messages/contents.py | 155 ++++++---- .../messages/payloads/discover_versions.py | 132 +++++++++ kmip/core/primitives.py | 5 +- kmip/demos/units/discover_versions.py | 67 +++++ kmip/demos/utils.py | 2 + kmip/services/kmip_client.py | 36 +++ kmip/services/results.py | 12 + .../core/factories/payloads/test_request.py | 6 +- .../core/factories/payloads/test_response.py | 6 +- .../__init__.py} | 2 +- .../contents/test_protocol_version.py | 198 +++++++++++++ .../payloads/test_discover_versions.py | 277 ++++++++++++++++++ kmip/tests/services/test_kmip_client.py | 72 +++++ 15 files changed, 923 insertions(+), 55 deletions(-) create mode 100644 kmip/core/messages/payloads/discover_versions.py create mode 100644 kmip/demos/units/discover_versions.py rename kmip/tests/core/messages/{test_contents.py => contents/__init__.py} (90%) create mode 100644 kmip/tests/core/messages/contents/test_protocol_version.py create mode 100644 kmip/tests/core/messages/payloads/test_discover_versions.py diff --git a/kmip/core/factories/payloads/request.py b/kmip/core/factories/payloads/request.py index b437c27..adc9958 100644 --- a/kmip/core/factories/payloads/request.py +++ b/kmip/core/factories/payloads/request.py @@ -18,6 +18,7 @@ from kmip.core.factories.payloads import PayloadFactory from kmip.core.messages.payloads import create from kmip.core.messages.payloads import create_key_pair from kmip.core.messages.payloads import destroy +from kmip.core.messages.payloads import discover_versions from kmip.core.messages.payloads import get from kmip.core.messages.payloads import locate from kmip.core.messages.payloads import rekey_key_pair @@ -46,3 +47,6 @@ class RequestPayloadFactory(PayloadFactory): def _create_destroy_payload(self): return destroy.DestroyRequestPayload() + + def _create_discover_versions_payload(self): + return discover_versions.DiscoverVersionsRequestPayload() diff --git a/kmip/core/factories/payloads/response.py b/kmip/core/factories/payloads/response.py index ee40587..69a7c31 100644 --- a/kmip/core/factories/payloads/response.py +++ b/kmip/core/factories/payloads/response.py @@ -18,6 +18,7 @@ from kmip.core.factories.payloads import PayloadFactory from kmip.core.messages.payloads import create from kmip.core.messages.payloads import create_key_pair from kmip.core.messages.payloads import destroy +from kmip.core.messages.payloads import discover_versions from kmip.core.messages.payloads import get from kmip.core.messages.payloads import locate from kmip.core.messages.payloads import rekey_key_pair @@ -46,3 +47,6 @@ class ResponsePayloadFactory(PayloadFactory): def _create_destroy_payload(self): return destroy.DestroyResponsePayload() + + def _create_discover_versions_payload(self): + return discover_versions.DiscoverVersionsResponsePayload() diff --git a/kmip/core/messages/contents.py b/kmip/core/messages/contents.py index 5794758..ca7e2c7 100644 --- a/kmip/core/messages/contents.py +++ b/kmip/core/messages/contents.py @@ -13,7 +13,14 @@ # License for the specific language governing permissions and limitations # under the License. -from kmip.core import enums +from kmip.core.enums import BatchErrorContinuationOption +from kmip.core.enums import KeyCompressionType +from kmip.core.enums import KeyFormatType +from kmip.core.enums import Operation +from kmip.core.enums import ResultStatus +from kmip.core.enums import ResultReason +from kmip.core.enums import Tags + from kmip.core import objects from kmip.core import utils @@ -31,28 +38,38 @@ class ProtocolVersion(Struct): class ProtocolVersionMajor(Integer): def __init__(self, value=None): - super(self.__class__, self).\ - __init__(value, enums.Tags.PROTOCOL_VERSION_MAJOR) + super(ProtocolVersion.ProtocolVersionMajor, self).\ + __init__(value, Tags.PROTOCOL_VERSION_MAJOR) class ProtocolVersionMinor(Integer): def __init__(self, value=None): - super(self.__class__, self).\ - __init__(value, enums.Tags.PROTOCOL_VERSION_MINOR) + super(ProtocolVersion.ProtocolVersionMinor, self).\ + __init__(value, Tags.PROTOCOL_VERSION_MINOR) def __init__(self, protocol_version_major=None, protocol_version_minor=None): - super(self.__class__, self).__init__(tag=enums.Tags.PROTOCOL_VERSION) - self.protocol_version_major = protocol_version_major - self.protocol_version_minor = protocol_version_minor + super(ProtocolVersion, self).__init__(Tags.PROTOCOL_VERSION) + + if protocol_version_major is None: + self.protocol_version_major = \ + ProtocolVersion.ProtocolVersionMajor() + else: + self.protocol_version_major = protocol_version_major + + if protocol_version_minor is None: + self.protocol_version_minor = \ + ProtocolVersion.ProtocolVersionMinor() + else: + self.protocol_version_minor = protocol_version_minor + + self.validate() def read(self, istream): - super(self.__class__, self).read(istream) + super(ProtocolVersion, self).read(istream) tstream = utils.BytearrayStream(istream.read(self.length)) # Read the major and minor portions of the version number - self.protocol_version_major = ProtocolVersion.ProtocolVersionMajor() - self.protocol_version_minor = ProtocolVersion.ProtocolVersionMinor() self.protocol_version_major.read(tstream) self.protocol_version_minor.read(tstream) @@ -67,57 +84,97 @@ class ProtocolVersion(Struct): # Write the length and value of the protocol version self.length = tstream.length() - super(self.__class__, self).write(ostream) + super(ProtocolVersion, self).write(ostream) ostream.write(tstream.buffer) def validate(self): - # TODO (peter-hamilton) Finish implementation. - pass + self.__validate() + + def __validate(self): + if not isinstance(self.protocol_version_major, + ProtocolVersion.ProtocolVersionMajor): + msg = "invalid protocol version major" + msg += "; expected {0}, received {1}".format( + ProtocolVersion.ProtocolVersionMajor, + self.protocol_version_major) + raise TypeError(msg) + + if not isinstance(self.protocol_version_minor, + ProtocolVersion.ProtocolVersionMinor): + msg = "invalid protocol version minor" + msg += "; expected {0}, received {1}".format( + ProtocolVersion.ProtocolVersionMinor, + self.protocol_version_minor) + raise TypeError(msg) + + def __eq__(self, other): + if isinstance(other, ProtocolVersion): + if ((self.protocol_version_major == + other.protocol_version_major) + and + (self.protocol_version_minor == + other.protocol_version_minor)): + return True + else: + return False + else: + return NotImplemented + + def __ne__(self, other): + if isinstance(other, ProtocolVersion): + return not self.__eq__(other) + else: + return NotImplemented + + def __repr__(self): + major = self.protocol_version_major.value + minor = self.protocol_version_minor.value + return "{0}.{1}".format(major, minor) @classmethod def create(cls, major, minor): - major_version = cls.ProtocolVersionMajor(major) - minor_version = cls.ProtocolVersionMinor(minor) - return ProtocolVersion(major_version, minor_version) + major = cls.ProtocolVersionMajor(major) + minor = cls.ProtocolVersionMinor(minor) + return ProtocolVersion(major, minor) # 6.2 class Operation(Enumeration): - ENUM_TYPE = enums.Operation + ENUM_TYPE = Operation def __init__(self, value=None): - super(self.__class__, self).__init__(value, enums.Tags.OPERATION) + super(Operation, self).__init__(value, Tags.OPERATION) # 6.3 class MaximumResponseSize(Integer): def __init__(self, value=None): - super(self.__class__, self).\ - __init__(value, enums.Tags.MAXIMUM_RESPONSE_SIZE) + super(MaximumResponseSize, self).\ + __init__(value, Tags.MAXIMUM_RESPONSE_SIZE) # 6.4 class UniqueBatchItemID(ByteString): def __init__(self, value=None): - super(self.__class__, self)\ - .__init__(value, enums.Tags.UNIQUE_BATCH_ITEM_ID) + super(UniqueBatchItemID, self)\ + .__init__(value, Tags.UNIQUE_BATCH_ITEM_ID) # 6.5 class TimeStamp(DateTime): def __init__(self, value=None): - super(self.__class__, self).__init__(value, enums.Tags.TIME_STAMP) + super(TimeStamp, self).__init__(value, Tags.TIME_STAMP) # 6.6 class Authentication(Struct): def __init__(self, credential=None): - super(self.__class__, self).__init__(tag=enums.Tags.AUTHENTICATION) + super(Authentication, self).__init__(Tags.AUTHENTICATION) self.credential = credential def read(self, istream): - super(self.__class__, self).read(istream) + super(Authentication, self).read(istream) tstream = utils.BytearrayStream(istream.read(self.length)) # Read the credential @@ -134,7 +191,7 @@ class Authentication(Struct): # Write the length and value of the protocol version self.length = tstream.length() - super(self.__class__, self).write(ostream) + super(Authentication, self).write(ostream) ostream.write(tstream.buffer) def validate(self): @@ -145,80 +202,80 @@ class Authentication(Struct): # 6.7 class AsynchronousIndicator(Boolean): def __init__(self, value=None): - super(self.__class__, self).\ - __init__(value, enums.Tags.ASYNCHRONOUS_INDICATOR) + super(AsynchronousIndicator, self).\ + __init__(value, Tags.ASYNCHRONOUS_INDICATOR) # 6.8 class AsynchronousCorrelationValue(ByteString): def __init__(self, value=None): - super(self.__class__, self).\ - __init__(value, enums.Tags.ASYNCHRONOUS_CORRELATION_VALUE) + super(AsynchronousCorrelationValue, self).\ + __init__(value, Tags.ASYNCHRONOUS_CORRELATION_VALUE) # 6.9 class ResultStatus(Enumeration): - ENUM_TYPE = enums.ResultStatus + ENUM_TYPE = ResultStatus def __init__(self, value=None): - super(self.__class__, self).__init__(value, enums.Tags.RESULT_STATUS) + super(ResultStatus, self).__init__(value, Tags.RESULT_STATUS) # 6.10 class ResultReason(Enumeration): - ENUM_TYPE = enums.ResultReason + ENUM_TYPE = ResultReason def __init__(self, value=None): - super(self.__class__, self).__init__(value, enums.Tags.RESULT_REASON) + super(ResultReason, self).__init__(value, Tags.RESULT_REASON) # 6.11 class ResultMessage(TextString): def __init__(self, value=None): - super(self.__class__, self).__init__(value, enums.Tags.RESULT_MESSAGE) + super(ResultMessage, self).__init__(value, Tags.RESULT_MESSAGE) # 6.12 class BatchOrderOption(Boolean): def __init__(self, value=None): - super(self.__class__, self).\ - __init__(value, enums.Tags.BATCH_ORDER_OPTION) + super(BatchOrderOption, self).\ + __init__(value, Tags.BATCH_ORDER_OPTION) # 6.13 class BatchErrorContinuationOption(Enumeration): - ENUM_TYPE = enums.BatchErrorContinuationOption + ENUM_TYPE = BatchErrorContinuationOption def __init__(self, value=None): - super(self.__class__, self).\ - __init__(value, enums.Tags.BATCH_ERROR_CONTINUATION_OPTION) + super(BatchErrorContinuationOption, self).\ + __init__(value, Tags.BATCH_ERROR_CONTINUATION_OPTION) # 6.14 class BatchCount(Integer): def __init__(self, value=None): - super(self.__class__, self).__init__(value, enums.Tags.BATCH_COUNT) + super(BatchCount, self).__init__(value, Tags.BATCH_COUNT) # 6.16 class MessageExtension(Struct): def __init__(self): - super(self.__class__, self).__init__(tag=enums.Tags.MESSAGE_EXTENSION) + super(MessageExtension, self).__init__(Tags.MESSAGE_EXTENSION) # 9.1.3.2.2 class KeyCompressionType(Enumeration): - ENUM_TYPE = enums.KeyCompressionType + ENUM_TYPE = KeyCompressionType def __init__(self, value=None): - super(self.__class__, self).\ - __init__(value, enums.Tags.KEY_COMPRESSION_TYPE) + super(KeyCompressionType, self).\ + __init__(value, Tags.KEY_COMPRESSION_TYPE) # 9.1.3.2.3 class KeyFormatType(Enumeration): - ENUM_TYPE = enums.KeyFormatType + ENUM_TYPE = KeyFormatType def __init__(self, value=None): - super(self.__class__, self).\ - __init__(value, enums.Tags.KEY_FORMAT_TYPE) + super(KeyFormatType, self).\ + __init__(value, Tags.KEY_FORMAT_TYPE) diff --git a/kmip/core/messages/payloads/discover_versions.py b/kmip/core/messages/payloads/discover_versions.py new file mode 100644 index 0000000..ad768ba --- /dev/null +++ b/kmip/core/messages/payloads/discover_versions.py @@ -0,0 +1,132 @@ +# Copyright (c) 2015 The Johns Hopkins University/Applied Physics Laboratory +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from six.moves import xrange + +from kmip.core.enums import Tags + +from kmip.core.messages.contents import ProtocolVersion + +from kmip.core.primitives import Struct + +from kmip.core.utils import BytearrayStream + + +class DiscoverVersionsRequestPayload(Struct): + + def __init__(self, protocol_versions=None): + super(DiscoverVersionsRequestPayload, self).__init__( + Tags.REQUEST_PAYLOAD) + + if protocol_versions is None: + self.protocol_versions = list() + else: + self.protocol_versions = protocol_versions + + self.validate() + + def read(self, istream): + super(DiscoverVersionsRequestPayload, self).read(istream) + tstream = BytearrayStream(istream.read(self.length)) + + while(self.is_tag_next(Tags.PROTOCOL_VERSION, tstream)): + protocol_version = ProtocolVersion() + protocol_version.read(tstream) + self.protocol_versions.append(protocol_version) + + self.is_oversized(tstream) + self.validate() + + def write(self, ostream): + tstream = BytearrayStream() + + for protocol_version in self.protocol_versions: + protocol_version.write(tstream) + + self.length = tstream.length() + super(DiscoverVersionsRequestPayload, self).write(ostream) + ostream.write(tstream.buffer) + + def validate(self): + self.__validate() + + def __validate(self): + if isinstance(self.protocol_versions, list): + for i in xrange(len(self.protocol_versions)): + protocol_version = self.protocol_versions[i] + if not isinstance(protocol_version, ProtocolVersion): + msg = "invalid protocol version ({0} in list)".format(i) + msg += "; expected {0}, received {1}".format( + ProtocolVersion, protocol_version) + raise TypeError(msg) + else: + msg = "invalid protocol versions list" + msg += "; expected {0}, received {1}".format( + list, self.protocol_versions) + raise TypeError(msg) + + +class DiscoverVersionsResponsePayload(Struct): + + def __init__(self, protocol_versions=None): + super(DiscoverVersionsResponsePayload, self).__init__( + Tags.RESPONSE_PAYLOAD) + + if protocol_versions is None: + self.protocol_versions = list() + else: + self.protocol_versions = protocol_versions + + self.validate() + + def read(self, istream): + super(DiscoverVersionsResponsePayload, self).read(istream) + tstream = BytearrayStream(istream.read(self.length)) + + while(self.is_tag_next(Tags.PROTOCOL_VERSION, tstream)): + protocol_version = ProtocolVersion() + protocol_version.read(tstream) + self.protocol_versions.append(protocol_version) + + self.is_oversized(tstream) + self.validate() + + def write(self, ostream): + tstream = BytearrayStream() + + for protocol_version in self.protocol_versions: + protocol_version.write(tstream) + + self.length = tstream.length() + super(DiscoverVersionsResponsePayload, self).write(ostream) + ostream.write(tstream.buffer) + + def validate(self): + self.__validate() + + def __validate(self): + if isinstance(self.protocol_versions, list): + for i in xrange(len(self.protocol_versions)): + protocol_version = self.protocol_versions[i] + if not isinstance(protocol_version, ProtocolVersion): + msg = "invalid protocol version ({0} in list)".format(i) + msg += "; expected {0}, received {1}".format( + ProtocolVersion, protocol_version) + raise TypeError(msg) + else: + msg = "invalid protocol versions list" + msg += "; expected {0}, received {1}".format( + list, self.protocol_versions) + raise TypeError(msg) diff --git a/kmip/core/primitives.py b/kmip/core/primitives.py index e7a1094..05e11a4 100644 --- a/kmip/core/primitives.py +++ b/kmip/core/primitives.py @@ -233,7 +233,10 @@ class Integer(Base): return NotImplemented def __ne__(self, other): - return not self.__eq__(other) + if isinstance(other, Integer): + return not self.__eq__(other) + else: + return NotImplemented class LongInteger(Base): diff --git a/kmip/demos/units/discover_versions.py b/kmip/demos/units/discover_versions.py new file mode 100644 index 0000000..4eac19c --- /dev/null +++ b/kmip/demos/units/discover_versions.py @@ -0,0 +1,67 @@ +# Copyright (c) 2015 The Johns Hopkins University/Applied Physics Laboratory +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from kmip.core.enums import Operation +from kmip.core.enums import ResultStatus + +from kmip.demos import utils + +from kmip.services.kmip_client import KMIPProxy + +import logging +import os +import sys + + +if __name__ == '__main__': + # Build and parse arguments + parser = utils.build_cli_parser(Operation.DISCOVER_VERSIONS) + opts, args = parser.parse_args(sys.argv[1:]) + + username = opts.username + password = opts.password + + # Build and setup logging and needed factories + f_log = os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, + 'logconfig.ini') + logging.config.fileConfig(f_log) + logger = logging.getLogger(__name__) + + # Build the client and connect to the server + client = KMIPProxy() + client.open() + + result = client.discover_versions() + client.close() + + # Display operation results + logger.debug('discover_versions() result status: {}'.format( + result.result_status.enum)) + + if result.result_status.enum == ResultStatus.SUCCESS: + protocol_versions = result.protocol_versions + if isinstance(protocol_versions, list): + logger.debug('number of protocol versions returned: {0}'.format( + len(protocol_versions))) + for protocol_version in protocol_versions: + logging.debug('protocol version supported: {0}'.format( + protocol_version)) + else: + logger.debug('number of protocol versions returned: 0') + else: + logger.debug('discover_versions() result reason: {}'.format( + result.result_reason.enum)) + logger.debug('discover_versions() result message: {}'.format( + result.result_message.value)) diff --git a/kmip/demos/utils.py b/kmip/demos/utils.py index 5fc30b2..f678c4a 100644 --- a/kmip/demos/utils.py +++ b/kmip/demos/utils.py @@ -119,6 +119,8 @@ def build_cli_parser(operation): default=None, dest="length", help="Key length in bits (e.g., 128, 256)") + elif operation is Operation.DISCOVER_VERSIONS: + pass else: raise ValueError("unrecognized operation: {0}".format(operation)) diff --git a/kmip/services/kmip_client.py b/kmip/services/kmip_client.py index 8fd3d66..8e037b5 100644 --- a/kmip/services/kmip_client.py +++ b/kmip/services/kmip_client.py @@ -16,6 +16,7 @@ from kmip.services.results import CreateResult from kmip.services.results import CreateKeyPairResult from kmip.services.results import DestroyResult +from kmip.services.results import DiscoverVersionsResult from kmip.services.results import GetResult from kmip.services.results import LocateResult from kmip.services.results import RegisterResult @@ -41,6 +42,7 @@ from kmip.core.messages import messages from kmip.core.messages.payloads import create from kmip.core.messages.payloads import create_key_pair from kmip.core.messages.payloads import destroy +from kmip.core.messages.payloads import discover_versions from kmip.core.messages.payloads import get from kmip.core.messages.payloads import locate from kmip.core.messages.payloads import rekey_key_pair @@ -174,6 +176,19 @@ class KMIPProxy(KMIP): object_group_member=object_group_member, attributes=attributes, credential=credential) + def discover_versions(self, batch=False, protocol_versions=None, + credential=None): + batch_item = self._build_discover_versions_batch_item( + protocol_versions) + + if batch: + self.batch_items.append(batch_item) + else: + request = self._build_request_message(credential, [batch_item]) + response = self._send_and_receive_message(request) + results = self._process_batch_items(response) + return results[0] + def _create(self, object_type=None, template_attribute=None, @@ -242,6 +257,16 @@ class KMIPProxy(KMIP): operation=operation, request_payload=payload) return batch_item + def _build_discover_versions_batch_item(self, protocol_versions=None): + operation = Operation(OperationEnum.DISCOVER_VERSIONS) + + payload = discover_versions.DiscoverVersionsRequestPayload( + protocol_versions) + + batch_item = messages.RequestBatchItem( + operation=operation, request_payload=payload) + return batch_item + def _process_batch_items(self, response): results = [] for batch_item in response.batch_items: @@ -256,6 +281,8 @@ class KMIPProxy(KMIP): return self._process_create_key_pair_batch_item elif operation == OperationEnum.REKEY_KEY_PAIR: return self._process_rekey_key_pair_batch_item + elif operation == OperationEnum.DISCOVER_VERSIONS: + return self._process_discover_versions_batch_item else: raise ValueError("no processor for operation: {0}".format( operation)) @@ -290,6 +317,15 @@ class KMIPProxy(KMIP): return self._process_key_pair_batch_item( batch_item, RekeyKeyPairResult) + def _process_discover_versions_batch_item(self, batch_item): + payload = batch_item.response_payload + + result = DiscoverVersionsResult( + batch_item.result_status, batch_item.result_reason, + batch_item.result_message, payload.protocol_versions) + + return result + def _get(self, unique_identifier=None, key_format_type=None, diff --git a/kmip/services/results.py b/kmip/services/results.py index 10f5d02..bdf754b 100644 --- a/kmip/services/results.py +++ b/kmip/services/results.py @@ -172,3 +172,15 @@ class LocateResult(OperationResult): result_reason, result_message) self.uuids = uuids + + +class DiscoverVersionsResult(OperationResult): + + def __init__(self, + result_status, + result_reason=None, + result_message=None, + protocol_versions=None): + super(DiscoverVersionsResult, self).__init__( + result_status, result_reason, result_message) + self.protocol_versions = protocol_versions diff --git a/kmip/tests/core/factories/payloads/test_request.py b/kmip/tests/core/factories/payloads/test_request.py index 30428d6..9f2f2b7 100644 --- a/kmip/tests/core/factories/payloads/test_request.py +++ b/kmip/tests/core/factories/payloads/test_request.py @@ -21,6 +21,7 @@ from kmip.core.factories.payloads.request import RequestPayloadFactory from kmip.core.messages.payloads import create from kmip.core.messages.payloads import create_key_pair from kmip.core.messages.payloads import destroy +from kmip.core.messages.payloads import discover_versions from kmip.core.messages.payloads import get from kmip.core.messages.payloads import locate from kmip.core.messages.payloads import rekey_key_pair @@ -162,5 +163,6 @@ class TestRequestPayloadFactory(testtools.TestCase): payload, rekey_key_pair.RekeyKeyPairRequestPayload) def test_create_discover_versions_payload(self): - self._test_not_implemented( - self.factory.create, Operation.DISCOVER_VERSIONS) + payload = self.factory.create(Operation.DISCOVER_VERSIONS) + self._test_payload_type( + payload, discover_versions.DiscoverVersionsRequestPayload) diff --git a/kmip/tests/core/factories/payloads/test_response.py b/kmip/tests/core/factories/payloads/test_response.py index 605d0dd..b47c31c 100644 --- a/kmip/tests/core/factories/payloads/test_response.py +++ b/kmip/tests/core/factories/payloads/test_response.py @@ -21,6 +21,7 @@ from kmip.core.factories.payloads.response import ResponsePayloadFactory from kmip.core.messages.payloads import create from kmip.core.messages.payloads import create_key_pair from kmip.core.messages.payloads import destroy +from kmip.core.messages.payloads import discover_versions from kmip.core.messages.payloads import get from kmip.core.messages.payloads import locate from kmip.core.messages.payloads import rekey_key_pair @@ -162,5 +163,6 @@ class TestResponsePayloadFactory(testtools.TestCase): payload, rekey_key_pair.RekeyKeyPairResponsePayload) def test_create_discover_versions_payload(self): - self._test_not_implemented( - self.factory.create, Operation.DISCOVER_VERSIONS) + payload = self.factory.create(Operation.DISCOVER_VERSIONS) + self._test_payload_type( + payload, discover_versions.DiscoverVersionsResponsePayload) diff --git a/kmip/tests/core/messages/test_contents.py b/kmip/tests/core/messages/contents/__init__.py similarity index 90% rename from kmip/tests/core/messages/test_contents.py rename to kmip/tests/core/messages/contents/__init__.py index 87b311e..417e2f9 100644 --- a/kmip/tests/core/messages/test_contents.py +++ b/kmip/tests/core/messages/contents/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2014 The Johns Hopkins University/Applied Physics Laboratory +# Copyright (c) 2015 The Johns Hopkins University/Applied Physics Laboratory # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may diff --git a/kmip/tests/core/messages/contents/test_protocol_version.py b/kmip/tests/core/messages/contents/test_protocol_version.py new file mode 100644 index 0000000..9c1425c --- /dev/null +++ b/kmip/tests/core/messages/contents/test_protocol_version.py @@ -0,0 +1,198 @@ +# Copyright (c) 2014 The Johns Hopkins University/Applied Physics Laboratory +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from testtools import TestCase + +from kmip.core.messages.contents import ProtocolVersion +from kmip.core.utils import BytearrayStream + + +class TestProtocolVersion(TestCase): + + def setUp(self): + super(TestProtocolVersion, self).setUp() + + self.major_default = ProtocolVersion.ProtocolVersionMajor() + self.minor_default = ProtocolVersion.ProtocolVersionMinor() + self.major = ProtocolVersion.ProtocolVersionMajor(1) + self.minor = ProtocolVersion.ProtocolVersionMinor(1) + + self.encoding_default = BytearrayStream(( + b'\x42\x00\x69\x01\x00\x00\x00\x20\x42\x00\x6A\x02\x00\x00\x00\x04' + b'\x00\x00\x00\x00\x00\x00\x00\x00\x42\x00\x6B\x02\x00\x00\x00\x04' + b'\x00\x00\x00\x00\x00\x00\x00\x00')) + self.encoding = BytearrayStream(( + b'\x42\x00\x69\x01\x00\x00\x00\x20\x42\x00\x6A\x02\x00\x00\x00\x04' + b'\x00\x00\x00\x01\x00\x00\x00\x00\x42\x00\x6B\x02\x00\x00\x00\x04' + b'\x00\x00\x00\x01\x00\x00\x00\x00')) + + def tearDown(self): + super(TestProtocolVersion, self).tearDown() + + def _test_init(self, protocol_version_major, protocol_version_minor): + protocol_version = ProtocolVersion( + protocol_version_major, protocol_version_minor) + + if protocol_version_major is None: + self.assertEqual(ProtocolVersion.ProtocolVersionMajor(), + protocol_version.protocol_version_major) + else: + self.assertEqual(protocol_version_major, + protocol_version.protocol_version_major) + + if protocol_version_minor is None: + self.assertEqual(ProtocolVersion.ProtocolVersionMinor(), + protocol_version.protocol_version_minor) + else: + self.assertEqual(protocol_version_minor, + protocol_version.protocol_version_minor) + + def test_init_with_none(self): + self._test_init(None, None) + + def test_init_with_args(self): + major = ProtocolVersion.ProtocolVersionMajor(1) + minor = ProtocolVersion.ProtocolVersionMinor(0) + + self._test_init(major, minor) + + def test_validate_on_invalid_protocol_version_major(self): + major = "invalid" + minor = ProtocolVersion.ProtocolVersionMinor(0) + args = [major, minor] + + self.assertRaisesRegexp( + TypeError, "invalid protocol version major", self._test_init, + *args) + + def test_validate_on_invalid_protocol_version_minor(self): + major = ProtocolVersion.ProtocolVersionMajor(1) + minor = "invalid" + args = [major, minor] + + self.assertRaisesRegexp( + TypeError, "invalid protocol version minor", self._test_init, + *args) + + def _test_read(self, stream, major, minor): + protocol_version = ProtocolVersion() + protocol_version.read(stream) + + msg = "protocol version major decoding mismatch" + msg += "; expected {0}, received {1}".format( + major, protocol_version.protocol_version_major) + self.assertEqual(major, protocol_version.protocol_version_major, msg) + + msg = "protocol version minor decoding mismatch" + msg += "; expected {0}, received {1}".format( + minor, protocol_version.protocol_version_minor) + self.assertEqual(minor, protocol_version.protocol_version_minor, msg) + + def test_read_with_none(self): + self._test_read(self.encoding_default, self.major_default, + self.minor_default) + + def test_read_with_args(self): + self._test_read(self.encoding, self.major, self.minor) + + def _test_write(self, stream_expected, major, minor): + stream_observed = BytearrayStream() + protocol_version = ProtocolVersion(major, minor) + protocol_version.write(stream_observed) + + length_expected = len(stream_expected) + length_observed = len(stream_observed) + + msg = "encoding lengths not equal" + msg += "; expected {0}, received {1}".format( + length_expected, length_observed) + self.assertEqual(length_expected, length_observed, msg) + + msg = "encoding mismatch" + msg += ";\nexpected:\n{0}\nreceived:\n{1}".format( + stream_expected, stream_observed) + + self.assertEqual(stream_expected, stream_observed, msg) + + def test_write_with_none(self): + self._test_write(self.encoding_default, self.major_default, + self.minor_default) + + def test_write_with_args(self): + self._test_write(self.encoding, self.major, self.minor) + + def test_equal_on_equal(self): + a = ProtocolVersion.create(1, 0) + b = ProtocolVersion.create(1, 0) + + self.assertTrue(a == b) + + def test_equal_on_not_equal(self): + a = ProtocolVersion.create(1, 0) + b = ProtocolVersion.create(0, 1) + + self.assertFalse(a == b) + + def test_equal_on_type_mismatch(self): + a = ProtocolVersion.create(1, 0) + b = "invalid" + + self.assertFalse(a == b) + + def test_not_equal_on_equal(self): + a = ProtocolVersion.create(1, 0) + b = ProtocolVersion.create(1, 0) + + self.assertFalse(a != b) + + def test_not_equal_on_not_equal(self): + a = ProtocolVersion.create(1, 0) + b = ProtocolVersion.create(0, 1) + + self.assertTrue(a != b) + + def test_not_equal_on_type_mismatch(self): + a = ProtocolVersion.create(1, 0) + b = "invalid" + + self.assertTrue(a != b) + + def test_repr(self): + a = ProtocolVersion.create(1, 0) + + self.assertEqual("1.0", "{0}".format(a)) + + def _test_create(self, major, minor): + protocol_version = ProtocolVersion.create(major, minor) + + if major is None: + expected = ProtocolVersion.ProtocolVersionMajor() + else: + expected = ProtocolVersion.ProtocolVersionMajor(major) + + self.assertEqual(expected, protocol_version.protocol_version_major) + + if minor is None: + expected = ProtocolVersion.ProtocolVersionMinor() + else: + expected = ProtocolVersion.ProtocolVersionMinor(minor) + + self.assertEqual(expected, protocol_version.protocol_version_minor) + + def test_create_with_none(self): + self._test_create(None, None) + + def test_create_with_args(self): + self._test_create(1, 0) diff --git a/kmip/tests/core/messages/payloads/test_discover_versions.py b/kmip/tests/core/messages/payloads/test_discover_versions.py new file mode 100644 index 0000000..e9fa456 --- /dev/null +++ b/kmip/tests/core/messages/payloads/test_discover_versions.py @@ -0,0 +1,277 @@ +# Copyright (c) 2014 The Johns Hopkins University/Applied Physics Laboratory +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from six.moves import xrange + +from testtools import TestCase + +from kmip.core import utils + +from kmip.core.messages.contents import ProtocolVersion +from kmip.core.messages.payloads import discover_versions + + +class TestDiscoverVersionsRequestPayload(TestCase): + + def setUp(self): + super(TestDiscoverVersionsRequestPayload, self).setUp() + + self.protocol_versions_empty = list() + self.protocol_versions_one = list() + self.protocol_versions_one.append(ProtocolVersion.create(1, 0)) + self.protocol_versions_two = list() + self.protocol_versions_two.append(ProtocolVersion.create(1, 1)) + self.protocol_versions_two.append(ProtocolVersion.create(1, 0)) + + self.encoding_empty = utils.BytearrayStream(( + b'\x42\x00\x79\x01\x00\x00\x00\x00')) + self.encoding_one = utils.BytearrayStream(( + b'\x42\x00\x79\x01\x00\x00\x00\x28\x42\x00\x69\x01\x00\x00\x00\x20' + b'\x42\x00\x6A\x02\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x00' + b'\x42\x00\x6B\x02\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00' + b'\x00')) + self.encoding_two = utils.BytearrayStream(( + b'\x42\x00\x79\x01\x00\x00\x00\x50\x42\x00\x69\x01\x00\x00\x00\x20' + b'\x42\x00\x6A\x02\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x00' + b'\x42\x00\x6B\x02\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x00' + b'\x42\x00\x69\x01\x00\x00\x00\x20\x42\x00\x6A\x02\x00\x00\x00\x04' + b'\x00\x00\x00\x01\x00\x00\x00\x00\x42\x00\x6B\x02\x00\x00\x00\x04' + b'\x00\x00\x00\x00\x00\x00\x00\x00')) + + def tearDown(self): + super(TestDiscoverVersionsRequestPayload, self).tearDown() + + def test_init_with_none(self): + discover_versions.DiscoverVersionsRequestPayload() + + def test_init_with_args(self): + discover_versions.DiscoverVersionsRequestPayload( + self.protocol_versions_empty) + + def test_validate_with_invalid_protocol_versions(self): + kwargs = {'protocol_versions': 'invalid'} + self.assertRaisesRegexp( + TypeError, "invalid protocol versions list", + discover_versions.DiscoverVersionsRequestPayload, **kwargs) + + def test_validate_with_invalid_protocol_version(self): + kwargs = {'protocol_versions': ['invalid']} + self.assertRaisesRegexp( + TypeError, "invalid protocol version", + discover_versions.DiscoverVersionsRequestPayload, **kwargs) + + def _test_read(self, stream, payload, protocol_versions): + payload.read(stream) + expected = len(protocol_versions) + observed = len(payload.protocol_versions) + + msg = "protocol versions list decoding mismatch" + msg += "; expected {0} results, received {1}".format( + expected, observed) + self.assertEqual(expected, observed, msg) + + for i in xrange(len(protocol_versions)): + expected = protocol_versions[i] + observed = payload.protocol_versions[i] + + msg = "protocol version decoding mismatch" + msg += "; expected {0}, received {1}".format(expected, observed) + self.assertEqual(expected, observed, msg) + + def test_read_with_empty_protocol_list(self): + stream = self.encoding_empty + payload = discover_versions.DiscoverVersionsRequestPayload() + protocol_versions = self.protocol_versions_empty + + self._test_read(stream, payload, protocol_versions) + + def test_read_with_one_protocol_version(self): + stream = self.encoding_one + payload = discover_versions.DiscoverVersionsRequestPayload() + protocol_versions = self.protocol_versions_one + + self._test_read(stream, payload, protocol_versions) + + def test_read_with_two_protocol_versions(self): + stream = self.encoding_two + payload = discover_versions.DiscoverVersionsRequestPayload() + protocol_versions = self.protocol_versions_two + + self._test_read(stream, payload, protocol_versions) + + def _test_write(self, payload, expected): + stream = utils.BytearrayStream() + payload.write(stream) + + length_expected = len(expected) + length_received = len(stream) + + msg = "encoding lengths not equal" + msg += "; expected {0}, received {1}".format( + length_expected, length_received) + self.assertEqual(length_expected, length_received, msg) + + msg = "encoding mismatch" + msg += ";\nexpected:\n{0}\nreceived:\n{1}".format(expected, stream) + + self.assertEqual(expected, stream, msg) + + def test_write_with_empty_protocol_list(self): + payload = discover_versions.DiscoverVersionsRequestPayload( + self.protocol_versions_empty) + expected = self.encoding_empty + + self._test_write(payload, expected) + + def test_write_with_one_protocol_version(self): + payload = discover_versions.DiscoverVersionsRequestPayload( + self.protocol_versions_one) + expected = self.encoding_one + + self._test_write(payload, expected) + + def test_write_with_two_protocol_versions(self): + payload = discover_versions.DiscoverVersionsRequestPayload( + self.protocol_versions_two) + expected = self.encoding_two + + self._test_write(payload, expected) + + +class TestDiscoverVersionsResponsePayload(TestCase): + + def setUp(self): + super(TestDiscoverVersionsResponsePayload, self).setUp() + + self.protocol_versions_empty = list() + self.protocol_versions_one = list() + self.protocol_versions_one.append(ProtocolVersion.create(1, 0)) + self.protocol_versions_two = list() + self.protocol_versions_two.append(ProtocolVersion.create(1, 1)) + self.protocol_versions_two.append(ProtocolVersion.create(1, 0)) + + self.encoding_empty = utils.BytearrayStream(( + b'\x42\x00\x7C\x01\x00\x00\x00\x00')) + self.encoding_one = utils.BytearrayStream(( + b'\x42\x00\x7C\x01\x00\x00\x00\x28\x42\x00\x69\x01\x00\x00\x00\x20' + b'\x42\x00\x6A\x02\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x00' + b'\x42\x00\x6B\x02\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00' + b'\x00')) + self.encoding_two = utils.BytearrayStream(( + b'\x42\x00\x7C\x01\x00\x00\x00\x50\x42\x00\x69\x01\x00\x00\x00\x20' + b'\x42\x00\x6A\x02\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x00' + b'\x42\x00\x6B\x02\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x00' + b'\x42\x00\x69\x01\x00\x00\x00\x20\x42\x00\x6A\x02\x00\x00\x00\x04' + b'\x00\x00\x00\x01\x00\x00\x00\x00\x42\x00\x6B\x02\x00\x00\x00\x04' + b'\x00\x00\x00\x00\x00\x00\x00\x00')) + + def tearDown(self): + super(TestDiscoverVersionsResponsePayload, self).tearDown() + + def test_init_with_none(self): + discover_versions.DiscoverVersionsResponsePayload() + + def test_init_with_args(self): + discover_versions.DiscoverVersionsResponsePayload( + self.protocol_versions_empty) + + def test_validate_with_invalid_protocol_versions(self): + kwargs = {'protocol_versions': 'invalid'} + self.assertRaisesRegexp( + TypeError, "invalid protocol versions list", + discover_versions.DiscoverVersionsResponsePayload, **kwargs) + + def test_validate_with_invalid_protocol_version(self): + kwargs = {'protocol_versions': ['invalid']} + self.assertRaisesRegexp( + TypeError, "invalid protocol version", + discover_versions.DiscoverVersionsResponsePayload, **kwargs) + + def _test_read(self, stream, payload, protocol_versions): + payload.read(stream) + expected = len(protocol_versions) + observed = len(payload.protocol_versions) + + msg = "protocol versions list decoding mismatch" + msg += "; expected {0} results, received {1}".format( + expected, observed) + self.assertEqual(expected, observed, msg) + + for i in xrange(len(protocol_versions)): + expected = protocol_versions[i] + observed = payload.protocol_versions[i] + + msg = "protocol version decoding mismatch" + msg += "; expected {0}, received {1}".format(expected, observed) + self.assertEqual(expected, observed, msg) + + def test_read_with_empty_protocol_list(self): + stream = self.encoding_empty + payload = discover_versions.DiscoverVersionsResponsePayload() + protocol_versions = self.protocol_versions_empty + + self._test_read(stream, payload, protocol_versions) + + def test_read_with_one_protocol_version(self): + stream = self.encoding_one + payload = discover_versions.DiscoverVersionsResponsePayload() + protocol_versions = self.protocol_versions_one + + self._test_read(stream, payload, protocol_versions) + + def test_read_with_two_protocol_versions(self): + stream = self.encoding_two + payload = discover_versions.DiscoverVersionsResponsePayload() + protocol_versions = self.protocol_versions_two + + self._test_read(stream, payload, protocol_versions) + + def _test_write(self, payload, expected): + stream = utils.BytearrayStream() + payload.write(stream) + + length_expected = len(expected) + length_received = len(stream) + + msg = "encoding lengths not equal" + msg += "; expected {0}, received {1}".format( + length_expected, length_received) + self.assertEqual(length_expected, length_received, msg) + + msg = "encoding mismatch" + msg += ";\nexpected:\n{0}\nreceived:\n{1}".format(expected, stream) + + self.assertEqual(expected, stream, msg) + + def test_write_with_empty_protocol_list(self): + payload = discover_versions.DiscoverVersionsResponsePayload( + self.protocol_versions_empty) + expected = self.encoding_empty + + self._test_write(payload, expected) + + def test_write_with_one_protocol_version(self): + payload = discover_versions.DiscoverVersionsResponsePayload( + self.protocol_versions_one) + expected = self.encoding_one + + self._test_write(payload, expected) + + def test_write_with_two_protocol_versions(self): + payload = discover_versions.DiscoverVersionsResponsePayload( + self.protocol_versions_two) + expected = self.encoding_two + + self._test_write(payload, expected) diff --git a/kmip/tests/services/test_kmip_client.py b/kmip/tests/services/test_kmip_client.py index 950fd3b..1b9880c 100644 --- a/kmip/tests/services/test_kmip_client.py +++ b/kmip/tests/services/test_kmip_client.py @@ -43,9 +43,14 @@ from kmip.core.factories.secrets import SecretFactory from kmip.core.messages.messages import RequestBatchItem from kmip.core.messages.messages import ResponseBatchItem from kmip.core.messages.messages import ResponseMessage + from kmip.core.messages.contents import Operation +from kmip.core.messages.contents import ProtocolVersion + from kmip.core.messages.payloads.create_key_pair import \ CreateKeyPairRequestPayload, CreateKeyPairResponsePayload +from kmip.core.messages.payloads.discover_versions import \ + DiscoverVersionsRequestPayload, DiscoverVersionsResponsePayload from kmip.core.messages.payloads.rekey_key_pair import \ RekeyKeyPairRequestPayload, RekeyKeyPairResponsePayload @@ -60,7 +65,9 @@ from kmip.core.objects import TemplateAttribute from kmip.core.secrets import SymmetricKey from kmip.services.kmip_client import KMIPProxy + from kmip.services.results import CreateKeyPairResult +from kmip.services.results import DiscoverVersionsResult from kmip.services.results import RekeyKeyPairResult import kmip.core.utils as utils @@ -587,6 +594,45 @@ class TestKMIPClient(TestCase): self._test_build_rekey_key_pair_batch_item( None, None, None, None, None) + def _test_build_discover_versions_batch_item(self, protocol_versions): + batch_item = self.client._build_discover_versions_batch_item( + protocol_versions) + + base = "expected {0}, received {1}" + msg = base.format(RequestBatchItem, batch_item) + self.assertIsInstance(batch_item, RequestBatchItem, msg) + + operation = batch_item.operation + + msg = base.format(Operation, operation) + self.assertIsInstance(operation, Operation, msg) + + operation_enum = operation.enum + + msg = base.format(OperationEnum.DISCOVER_VERSIONS, operation_enum) + self.assertEqual(OperationEnum.DISCOVER_VERSIONS, operation_enum, msg) + + payload = batch_item.request_payload + + if protocol_versions is None: + protocol_versions = list() + + msg = base.format(DiscoverVersionsRequestPayload, payload) + self.assertIsInstance(payload, DiscoverVersionsRequestPayload, msg) + + observed = payload.protocol_versions + + msg = base.format(protocol_versions, observed) + self.assertEqual(protocol_versions, observed, msg) + + def test_build_discover_versions_batch_item_with_input(self): + protocol_versions = [ProtocolVersion.create(1, 0)] + self._test_build_discover_versions_batch_item(protocol_versions) + + def test_build_discover_versions_batch_item_no_input(self): + protocol_versions = None + self._test_build_discover_versions_batch_item(protocol_versions) + def test_process_batch_items(self): batch_item = ResponseBatchItem( operation=Operation(OperationEnum.CREATE_KEY_PAIR), @@ -651,3 +697,29 @@ class TestKMIPClient(TestCase): msg = "expected {0}, received {1}".format(RekeyKeyPairResult, result) self.assertIsInstance(result, RekeyKeyPairResult, msg) + + def _test_process_discover_versions_batch_item(self, protocol_versions): + batch_item = ResponseBatchItem( + operation=Operation(OperationEnum.DISCOVER_VERSIONS), + response_payload=DiscoverVersionsResponsePayload( + protocol_versions)) + result = self.client._process_discover_versions_batch_item(batch_item) + + base = "expected {0}, received {1}" + msg = base.format(DiscoverVersionsResult, result) + self.assertIsInstance(result, DiscoverVersionsResult, msg) + + # The payload maps protocol_versions to an empty list on None + if protocol_versions is None: + protocol_versions = list() + + msg = base.format(protocol_versions, result.protocol_versions) + self.assertEqual(protocol_versions, result.protocol_versions, msg) + + def test_process_discover_versions_batch_item_with_results(self): + protocol_versions = [ProtocolVersion.create(1, 0)] + self._test_process_discover_versions_batch_item(protocol_versions) + + def test_process_discover_versions_batch_item_no_results(self): + protocol_versions = None + self._test_process_discover_versions_batch_item(protocol_versions)