Adding support for the DiscoverVersions operation

This change adds support for the DiscoverVersions operation, including
updates to the KMIP client, the client and KMIP core test suites, and a
DiscoverVersions unit demo.
This commit is contained in:
Peter Hamilton 2015-01-23 15:26:28 -05:00
parent 7ce5a74315
commit f6b420d2db
15 changed files with 923 additions and 55 deletions

View File

@ -18,6 +18,7 @@ from kmip.core.factories.payloads import PayloadFactory
from kmip.core.messages.payloads import create
from kmip.core.messages.payloads import create_key_pair
from kmip.core.messages.payloads import destroy
from kmip.core.messages.payloads import discover_versions
from kmip.core.messages.payloads import get
from kmip.core.messages.payloads import locate
from kmip.core.messages.payloads import rekey_key_pair
@ -46,3 +47,6 @@ class RequestPayloadFactory(PayloadFactory):
def _create_destroy_payload(self):
return destroy.DestroyRequestPayload()
def _create_discover_versions_payload(self):
return discover_versions.DiscoverVersionsRequestPayload()

View File

@ -18,6 +18,7 @@ from kmip.core.factories.payloads import PayloadFactory
from kmip.core.messages.payloads import create
from kmip.core.messages.payloads import create_key_pair
from kmip.core.messages.payloads import destroy
from kmip.core.messages.payloads import discover_versions
from kmip.core.messages.payloads import get
from kmip.core.messages.payloads import locate
from kmip.core.messages.payloads import rekey_key_pair
@ -46,3 +47,6 @@ class ResponsePayloadFactory(PayloadFactory):
def _create_destroy_payload(self):
return destroy.DestroyResponsePayload()
def _create_discover_versions_payload(self):
return discover_versions.DiscoverVersionsResponsePayload()

View File

@ -13,7 +13,14 @@
# License for the specific language governing permissions and limitations
# under the License.
from kmip.core import enums
from kmip.core.enums import BatchErrorContinuationOption
from kmip.core.enums import KeyCompressionType
from kmip.core.enums import KeyFormatType
from kmip.core.enums import Operation
from kmip.core.enums import ResultStatus
from kmip.core.enums import ResultReason
from kmip.core.enums import Tags
from kmip.core import objects
from kmip.core import utils
@ -31,28 +38,38 @@ class ProtocolVersion(Struct):
class ProtocolVersionMajor(Integer):
def __init__(self, value=None):
super(self.__class__, self).\
__init__(value, enums.Tags.PROTOCOL_VERSION_MAJOR)
super(ProtocolVersion.ProtocolVersionMajor, self).\
__init__(value, Tags.PROTOCOL_VERSION_MAJOR)
class ProtocolVersionMinor(Integer):
def __init__(self, value=None):
super(self.__class__, self).\
__init__(value, enums.Tags.PROTOCOL_VERSION_MINOR)
super(ProtocolVersion.ProtocolVersionMinor, self).\
__init__(value, Tags.PROTOCOL_VERSION_MINOR)
def __init__(self,
protocol_version_major=None,
protocol_version_minor=None):
super(self.__class__, self).__init__(tag=enums.Tags.PROTOCOL_VERSION)
super(ProtocolVersion, self).__init__(Tags.PROTOCOL_VERSION)
if protocol_version_major is None:
self.protocol_version_major = \
ProtocolVersion.ProtocolVersionMajor()
else:
self.protocol_version_major = protocol_version_major
if protocol_version_minor is None:
self.protocol_version_minor = \
ProtocolVersion.ProtocolVersionMinor()
else:
self.protocol_version_minor = protocol_version_minor
self.validate()
def read(self, istream):
super(self.__class__, self).read(istream)
super(ProtocolVersion, self).read(istream)
tstream = utils.BytearrayStream(istream.read(self.length))
# Read the major and minor portions of the version number
self.protocol_version_major = ProtocolVersion.ProtocolVersionMajor()
self.protocol_version_minor = ProtocolVersion.ProtocolVersionMinor()
self.protocol_version_major.read(tstream)
self.protocol_version_minor.read(tstream)
@ -67,57 +84,97 @@ class ProtocolVersion(Struct):
# Write the length and value of the protocol version
self.length = tstream.length()
super(self.__class__, self).write(ostream)
super(ProtocolVersion, self).write(ostream)
ostream.write(tstream.buffer)
def validate(self):
# TODO (peter-hamilton) Finish implementation.
pass
self.__validate()
def __validate(self):
if not isinstance(self.protocol_version_major,
ProtocolVersion.ProtocolVersionMajor):
msg = "invalid protocol version major"
msg += "; expected {0}, received {1}".format(
ProtocolVersion.ProtocolVersionMajor,
self.protocol_version_major)
raise TypeError(msg)
if not isinstance(self.protocol_version_minor,
ProtocolVersion.ProtocolVersionMinor):
msg = "invalid protocol version minor"
msg += "; expected {0}, received {1}".format(
ProtocolVersion.ProtocolVersionMinor,
self.protocol_version_minor)
raise TypeError(msg)
def __eq__(self, other):
if isinstance(other, ProtocolVersion):
if ((self.protocol_version_major ==
other.protocol_version_major)
and
(self.protocol_version_minor ==
other.protocol_version_minor)):
return True
else:
return False
else:
return NotImplemented
def __ne__(self, other):
if isinstance(other, ProtocolVersion):
return not self.__eq__(other)
else:
return NotImplemented
def __repr__(self):
major = self.protocol_version_major.value
minor = self.protocol_version_minor.value
return "{0}.{1}".format(major, minor)
@classmethod
def create(cls, major, minor):
major_version = cls.ProtocolVersionMajor(major)
minor_version = cls.ProtocolVersionMinor(minor)
return ProtocolVersion(major_version, minor_version)
major = cls.ProtocolVersionMajor(major)
minor = cls.ProtocolVersionMinor(minor)
return ProtocolVersion(major, minor)
# 6.2
class Operation(Enumeration):
ENUM_TYPE = enums.Operation
ENUM_TYPE = Operation
def __init__(self, value=None):
super(self.__class__, self).__init__(value, enums.Tags.OPERATION)
super(Operation, self).__init__(value, Tags.OPERATION)
# 6.3
class MaximumResponseSize(Integer):
def __init__(self, value=None):
super(self.__class__, self).\
__init__(value, enums.Tags.MAXIMUM_RESPONSE_SIZE)
super(MaximumResponseSize, self).\
__init__(value, Tags.MAXIMUM_RESPONSE_SIZE)
# 6.4
class UniqueBatchItemID(ByteString):
def __init__(self, value=None):
super(self.__class__, self)\
.__init__(value, enums.Tags.UNIQUE_BATCH_ITEM_ID)
super(UniqueBatchItemID, self)\
.__init__(value, Tags.UNIQUE_BATCH_ITEM_ID)
# 6.5
class TimeStamp(DateTime):
def __init__(self, value=None):
super(self.__class__, self).__init__(value, enums.Tags.TIME_STAMP)
super(TimeStamp, self).__init__(value, Tags.TIME_STAMP)
# 6.6
class Authentication(Struct):
def __init__(self, credential=None):
super(self.__class__, self).__init__(tag=enums.Tags.AUTHENTICATION)
super(Authentication, self).__init__(Tags.AUTHENTICATION)
self.credential = credential
def read(self, istream):
super(self.__class__, self).read(istream)
super(Authentication, self).read(istream)
tstream = utils.BytearrayStream(istream.read(self.length))
# Read the credential
@ -134,7 +191,7 @@ class Authentication(Struct):
# Write the length and value of the protocol version
self.length = tstream.length()
super(self.__class__, self).write(ostream)
super(Authentication, self).write(ostream)
ostream.write(tstream.buffer)
def validate(self):
@ -145,80 +202,80 @@ class Authentication(Struct):
# 6.7
class AsynchronousIndicator(Boolean):
def __init__(self, value=None):
super(self.__class__, self).\
__init__(value, enums.Tags.ASYNCHRONOUS_INDICATOR)
super(AsynchronousIndicator, self).\
__init__(value, Tags.ASYNCHRONOUS_INDICATOR)
# 6.8
class AsynchronousCorrelationValue(ByteString):
def __init__(self, value=None):
super(self.__class__, self).\
__init__(value, enums.Tags.ASYNCHRONOUS_CORRELATION_VALUE)
super(AsynchronousCorrelationValue, self).\
__init__(value, Tags.ASYNCHRONOUS_CORRELATION_VALUE)
# 6.9
class ResultStatus(Enumeration):
ENUM_TYPE = enums.ResultStatus
ENUM_TYPE = ResultStatus
def __init__(self, value=None):
super(self.__class__, self).__init__(value, enums.Tags.RESULT_STATUS)
super(ResultStatus, self).__init__(value, Tags.RESULT_STATUS)
# 6.10
class ResultReason(Enumeration):
ENUM_TYPE = enums.ResultReason
ENUM_TYPE = ResultReason
def __init__(self, value=None):
super(self.__class__, self).__init__(value, enums.Tags.RESULT_REASON)
super(ResultReason, self).__init__(value, Tags.RESULT_REASON)
# 6.11
class ResultMessage(TextString):
def __init__(self, value=None):
super(self.__class__, self).__init__(value, enums.Tags.RESULT_MESSAGE)
super(ResultMessage, self).__init__(value, Tags.RESULT_MESSAGE)
# 6.12
class BatchOrderOption(Boolean):
def __init__(self, value=None):
super(self.__class__, self).\
__init__(value, enums.Tags.BATCH_ORDER_OPTION)
super(BatchOrderOption, self).\
__init__(value, Tags.BATCH_ORDER_OPTION)
# 6.13
class BatchErrorContinuationOption(Enumeration):
ENUM_TYPE = enums.BatchErrorContinuationOption
ENUM_TYPE = BatchErrorContinuationOption
def __init__(self, value=None):
super(self.__class__, self).\
__init__(value, enums.Tags.BATCH_ERROR_CONTINUATION_OPTION)
super(BatchErrorContinuationOption, self).\
__init__(value, Tags.BATCH_ERROR_CONTINUATION_OPTION)
# 6.14
class BatchCount(Integer):
def __init__(self, value=None):
super(self.__class__, self).__init__(value, enums.Tags.BATCH_COUNT)
super(BatchCount, self).__init__(value, Tags.BATCH_COUNT)
# 6.16
class MessageExtension(Struct):
def __init__(self):
super(self.__class__, self).__init__(tag=enums.Tags.MESSAGE_EXTENSION)
super(MessageExtension, self).__init__(Tags.MESSAGE_EXTENSION)
# 9.1.3.2.2
class KeyCompressionType(Enumeration):
ENUM_TYPE = enums.KeyCompressionType
ENUM_TYPE = KeyCompressionType
def __init__(self, value=None):
super(self.__class__, self).\
__init__(value, enums.Tags.KEY_COMPRESSION_TYPE)
super(KeyCompressionType, self).\
__init__(value, Tags.KEY_COMPRESSION_TYPE)
# 9.1.3.2.3
class KeyFormatType(Enumeration):
ENUM_TYPE = enums.KeyFormatType
ENUM_TYPE = KeyFormatType
def __init__(self, value=None):
super(self.__class__, self).\
__init__(value, enums.Tags.KEY_FORMAT_TYPE)
super(KeyFormatType, self).\
__init__(value, Tags.KEY_FORMAT_TYPE)

View File

@ -0,0 +1,132 @@
# Copyright (c) 2015 The Johns Hopkins University/Applied Physics Laboratory
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from six.moves import xrange
from kmip.core.enums import Tags
from kmip.core.messages.contents import ProtocolVersion
from kmip.core.primitives import Struct
from kmip.core.utils import BytearrayStream
class DiscoverVersionsRequestPayload(Struct):
def __init__(self, protocol_versions=None):
super(DiscoverVersionsRequestPayload, self).__init__(
Tags.REQUEST_PAYLOAD)
if protocol_versions is None:
self.protocol_versions = list()
else:
self.protocol_versions = protocol_versions
self.validate()
def read(self, istream):
super(DiscoverVersionsRequestPayload, self).read(istream)
tstream = BytearrayStream(istream.read(self.length))
while(self.is_tag_next(Tags.PROTOCOL_VERSION, tstream)):
protocol_version = ProtocolVersion()
protocol_version.read(tstream)
self.protocol_versions.append(protocol_version)
self.is_oversized(tstream)
self.validate()
def write(self, ostream):
tstream = BytearrayStream()
for protocol_version in self.protocol_versions:
protocol_version.write(tstream)
self.length = tstream.length()
super(DiscoverVersionsRequestPayload, self).write(ostream)
ostream.write(tstream.buffer)
def validate(self):
self.__validate()
def __validate(self):
if isinstance(self.protocol_versions, list):
for i in xrange(len(self.protocol_versions)):
protocol_version = self.protocol_versions[i]
if not isinstance(protocol_version, ProtocolVersion):
msg = "invalid protocol version ({0} in list)".format(i)
msg += "; expected {0}, received {1}".format(
ProtocolVersion, protocol_version)
raise TypeError(msg)
else:
msg = "invalid protocol versions list"
msg += "; expected {0}, received {1}".format(
list, self.protocol_versions)
raise TypeError(msg)
class DiscoverVersionsResponsePayload(Struct):
def __init__(self, protocol_versions=None):
super(DiscoverVersionsResponsePayload, self).__init__(
Tags.RESPONSE_PAYLOAD)
if protocol_versions is None:
self.protocol_versions = list()
else:
self.protocol_versions = protocol_versions
self.validate()
def read(self, istream):
super(DiscoverVersionsResponsePayload, self).read(istream)
tstream = BytearrayStream(istream.read(self.length))
while(self.is_tag_next(Tags.PROTOCOL_VERSION, tstream)):
protocol_version = ProtocolVersion()
protocol_version.read(tstream)
self.protocol_versions.append(protocol_version)
self.is_oversized(tstream)
self.validate()
def write(self, ostream):
tstream = BytearrayStream()
for protocol_version in self.protocol_versions:
protocol_version.write(tstream)
self.length = tstream.length()
super(DiscoverVersionsResponsePayload, self).write(ostream)
ostream.write(tstream.buffer)
def validate(self):
self.__validate()
def __validate(self):
if isinstance(self.protocol_versions, list):
for i in xrange(len(self.protocol_versions)):
protocol_version = self.protocol_versions[i]
if not isinstance(protocol_version, ProtocolVersion):
msg = "invalid protocol version ({0} in list)".format(i)
msg += "; expected {0}, received {1}".format(
ProtocolVersion, protocol_version)
raise TypeError(msg)
else:
msg = "invalid protocol versions list"
msg += "; expected {0}, received {1}".format(
list, self.protocol_versions)
raise TypeError(msg)

View File

@ -233,7 +233,10 @@ class Integer(Base):
return NotImplemented
def __ne__(self, other):
if isinstance(other, Integer):
return not self.__eq__(other)
else:
return NotImplemented
class LongInteger(Base):

View File

@ -0,0 +1,67 @@
# Copyright (c) 2015 The Johns Hopkins University/Applied Physics Laboratory
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from kmip.core.enums import Operation
from kmip.core.enums import ResultStatus
from kmip.demos import utils
from kmip.services.kmip_client import KMIPProxy
import logging
import os
import sys
if __name__ == '__main__':
# Build and parse arguments
parser = utils.build_cli_parser(Operation.DISCOVER_VERSIONS)
opts, args = parser.parse_args(sys.argv[1:])
username = opts.username
password = opts.password
# Build and setup logging and needed factories
f_log = os.path.join(os.path.dirname(__file__), os.pardir, os.pardir,
'logconfig.ini')
logging.config.fileConfig(f_log)
logger = logging.getLogger(__name__)
# Build the client and connect to the server
client = KMIPProxy()
client.open()
result = client.discover_versions()
client.close()
# Display operation results
logger.debug('discover_versions() result status: {}'.format(
result.result_status.enum))
if result.result_status.enum == ResultStatus.SUCCESS:
protocol_versions = result.protocol_versions
if isinstance(protocol_versions, list):
logger.debug('number of protocol versions returned: {0}'.format(
len(protocol_versions)))
for protocol_version in protocol_versions:
logging.debug('protocol version supported: {0}'.format(
protocol_version))
else:
logger.debug('number of protocol versions returned: 0')
else:
logger.debug('discover_versions() result reason: {}'.format(
result.result_reason.enum))
logger.debug('discover_versions() result message: {}'.format(
result.result_message.value))

View File

@ -119,6 +119,8 @@ def build_cli_parser(operation):
default=None,
dest="length",
help="Key length in bits (e.g., 128, 256)")
elif operation is Operation.DISCOVER_VERSIONS:
pass
else:
raise ValueError("unrecognized operation: {0}".format(operation))

View File

@ -16,6 +16,7 @@
from kmip.services.results import CreateResult
from kmip.services.results import CreateKeyPairResult
from kmip.services.results import DestroyResult
from kmip.services.results import DiscoverVersionsResult
from kmip.services.results import GetResult
from kmip.services.results import LocateResult
from kmip.services.results import RegisterResult
@ -41,6 +42,7 @@ from kmip.core.messages import messages
from kmip.core.messages.payloads import create
from kmip.core.messages.payloads import create_key_pair
from kmip.core.messages.payloads import destroy
from kmip.core.messages.payloads import discover_versions
from kmip.core.messages.payloads import get
from kmip.core.messages.payloads import locate
from kmip.core.messages.payloads import rekey_key_pair
@ -174,6 +176,19 @@ class KMIPProxy(KMIP):
object_group_member=object_group_member,
attributes=attributes, credential=credential)
def discover_versions(self, batch=False, protocol_versions=None,
credential=None):
batch_item = self._build_discover_versions_batch_item(
protocol_versions)
if batch:
self.batch_items.append(batch_item)
else:
request = self._build_request_message(credential, [batch_item])
response = self._send_and_receive_message(request)
results = self._process_batch_items(response)
return results[0]
def _create(self,
object_type=None,
template_attribute=None,
@ -242,6 +257,16 @@ class KMIPProxy(KMIP):
operation=operation, request_payload=payload)
return batch_item
def _build_discover_versions_batch_item(self, protocol_versions=None):
operation = Operation(OperationEnum.DISCOVER_VERSIONS)
payload = discover_versions.DiscoverVersionsRequestPayload(
protocol_versions)
batch_item = messages.RequestBatchItem(
operation=operation, request_payload=payload)
return batch_item
def _process_batch_items(self, response):
results = []
for batch_item in response.batch_items:
@ -256,6 +281,8 @@ class KMIPProxy(KMIP):
return self._process_create_key_pair_batch_item
elif operation == OperationEnum.REKEY_KEY_PAIR:
return self._process_rekey_key_pair_batch_item
elif operation == OperationEnum.DISCOVER_VERSIONS:
return self._process_discover_versions_batch_item
else:
raise ValueError("no processor for operation: {0}".format(
operation))
@ -290,6 +317,15 @@ class KMIPProxy(KMIP):
return self._process_key_pair_batch_item(
batch_item, RekeyKeyPairResult)
def _process_discover_versions_batch_item(self, batch_item):
payload = batch_item.response_payload
result = DiscoverVersionsResult(
batch_item.result_status, batch_item.result_reason,
batch_item.result_message, payload.protocol_versions)
return result
def _get(self,
unique_identifier=None,
key_format_type=None,

View File

@ -172,3 +172,15 @@ class LocateResult(OperationResult):
result_reason,
result_message)
self.uuids = uuids
class DiscoverVersionsResult(OperationResult):
def __init__(self,
result_status,
result_reason=None,
result_message=None,
protocol_versions=None):
super(DiscoverVersionsResult, self).__init__(
result_status, result_reason, result_message)
self.protocol_versions = protocol_versions

View File

@ -21,6 +21,7 @@ from kmip.core.factories.payloads.request import RequestPayloadFactory
from kmip.core.messages.payloads import create
from kmip.core.messages.payloads import create_key_pair
from kmip.core.messages.payloads import destroy
from kmip.core.messages.payloads import discover_versions
from kmip.core.messages.payloads import get
from kmip.core.messages.payloads import locate
from kmip.core.messages.payloads import rekey_key_pair
@ -162,5 +163,6 @@ class TestRequestPayloadFactory(testtools.TestCase):
payload, rekey_key_pair.RekeyKeyPairRequestPayload)
def test_create_discover_versions_payload(self):
self._test_not_implemented(
self.factory.create, Operation.DISCOVER_VERSIONS)
payload = self.factory.create(Operation.DISCOVER_VERSIONS)
self._test_payload_type(
payload, discover_versions.DiscoverVersionsRequestPayload)

View File

@ -21,6 +21,7 @@ from kmip.core.factories.payloads.response import ResponsePayloadFactory
from kmip.core.messages.payloads import create
from kmip.core.messages.payloads import create_key_pair
from kmip.core.messages.payloads import destroy
from kmip.core.messages.payloads import discover_versions
from kmip.core.messages.payloads import get
from kmip.core.messages.payloads import locate
from kmip.core.messages.payloads import rekey_key_pair
@ -162,5 +163,6 @@ class TestResponsePayloadFactory(testtools.TestCase):
payload, rekey_key_pair.RekeyKeyPairResponsePayload)
def test_create_discover_versions_payload(self):
self._test_not_implemented(
self.factory.create, Operation.DISCOVER_VERSIONS)
payload = self.factory.create(Operation.DISCOVER_VERSIONS)
self._test_payload_type(
payload, discover_versions.DiscoverVersionsResponsePayload)

View File

@ -1,4 +1,4 @@
# Copyright (c) 2014 The Johns Hopkins University/Applied Physics Laboratory
# Copyright (c) 2015 The Johns Hopkins University/Applied Physics Laboratory
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may

View File

@ -0,0 +1,198 @@
# Copyright (c) 2014 The Johns Hopkins University/Applied Physics Laboratory
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from testtools import TestCase
from kmip.core.messages.contents import ProtocolVersion
from kmip.core.utils import BytearrayStream
class TestProtocolVersion(TestCase):
def setUp(self):
super(TestProtocolVersion, self).setUp()
self.major_default = ProtocolVersion.ProtocolVersionMajor()
self.minor_default = ProtocolVersion.ProtocolVersionMinor()
self.major = ProtocolVersion.ProtocolVersionMajor(1)
self.minor = ProtocolVersion.ProtocolVersionMinor(1)
self.encoding_default = BytearrayStream((
b'\x42\x00\x69\x01\x00\x00\x00\x20\x42\x00\x6A\x02\x00\x00\x00\x04'
b'\x00\x00\x00\x00\x00\x00\x00\x00\x42\x00\x6B\x02\x00\x00\x00\x04'
b'\x00\x00\x00\x00\x00\x00\x00\x00'))
self.encoding = BytearrayStream((
b'\x42\x00\x69\x01\x00\x00\x00\x20\x42\x00\x6A\x02\x00\x00\x00\x04'
b'\x00\x00\x00\x01\x00\x00\x00\x00\x42\x00\x6B\x02\x00\x00\x00\x04'
b'\x00\x00\x00\x01\x00\x00\x00\x00'))
def tearDown(self):
super(TestProtocolVersion, self).tearDown()
def _test_init(self, protocol_version_major, protocol_version_minor):
protocol_version = ProtocolVersion(
protocol_version_major, protocol_version_minor)
if protocol_version_major is None:
self.assertEqual(ProtocolVersion.ProtocolVersionMajor(),
protocol_version.protocol_version_major)
else:
self.assertEqual(protocol_version_major,
protocol_version.protocol_version_major)
if protocol_version_minor is None:
self.assertEqual(ProtocolVersion.ProtocolVersionMinor(),
protocol_version.protocol_version_minor)
else:
self.assertEqual(protocol_version_minor,
protocol_version.protocol_version_minor)
def test_init_with_none(self):
self._test_init(None, None)
def test_init_with_args(self):
major = ProtocolVersion.ProtocolVersionMajor(1)
minor = ProtocolVersion.ProtocolVersionMinor(0)
self._test_init(major, minor)
def test_validate_on_invalid_protocol_version_major(self):
major = "invalid"
minor = ProtocolVersion.ProtocolVersionMinor(0)
args = [major, minor]
self.assertRaisesRegexp(
TypeError, "invalid protocol version major", self._test_init,
*args)
def test_validate_on_invalid_protocol_version_minor(self):
major = ProtocolVersion.ProtocolVersionMajor(1)
minor = "invalid"
args = [major, minor]
self.assertRaisesRegexp(
TypeError, "invalid protocol version minor", self._test_init,
*args)
def _test_read(self, stream, major, minor):
protocol_version = ProtocolVersion()
protocol_version.read(stream)
msg = "protocol version major decoding mismatch"
msg += "; expected {0}, received {1}".format(
major, protocol_version.protocol_version_major)
self.assertEqual(major, protocol_version.protocol_version_major, msg)
msg = "protocol version minor decoding mismatch"
msg += "; expected {0}, received {1}".format(
minor, protocol_version.protocol_version_minor)
self.assertEqual(minor, protocol_version.protocol_version_minor, msg)
def test_read_with_none(self):
self._test_read(self.encoding_default, self.major_default,
self.minor_default)
def test_read_with_args(self):
self._test_read(self.encoding, self.major, self.minor)
def _test_write(self, stream_expected, major, minor):
stream_observed = BytearrayStream()
protocol_version = ProtocolVersion(major, minor)
protocol_version.write(stream_observed)
length_expected = len(stream_expected)
length_observed = len(stream_observed)
msg = "encoding lengths not equal"
msg += "; expected {0}, received {1}".format(
length_expected, length_observed)
self.assertEqual(length_expected, length_observed, msg)
msg = "encoding mismatch"
msg += ";\nexpected:\n{0}\nreceived:\n{1}".format(
stream_expected, stream_observed)
self.assertEqual(stream_expected, stream_observed, msg)
def test_write_with_none(self):
self._test_write(self.encoding_default, self.major_default,
self.minor_default)
def test_write_with_args(self):
self._test_write(self.encoding, self.major, self.minor)
def test_equal_on_equal(self):
a = ProtocolVersion.create(1, 0)
b = ProtocolVersion.create(1, 0)
self.assertTrue(a == b)
def test_equal_on_not_equal(self):
a = ProtocolVersion.create(1, 0)
b = ProtocolVersion.create(0, 1)
self.assertFalse(a == b)
def test_equal_on_type_mismatch(self):
a = ProtocolVersion.create(1, 0)
b = "invalid"
self.assertFalse(a == b)
def test_not_equal_on_equal(self):
a = ProtocolVersion.create(1, 0)
b = ProtocolVersion.create(1, 0)
self.assertFalse(a != b)
def test_not_equal_on_not_equal(self):
a = ProtocolVersion.create(1, 0)
b = ProtocolVersion.create(0, 1)
self.assertTrue(a != b)
def test_not_equal_on_type_mismatch(self):
a = ProtocolVersion.create(1, 0)
b = "invalid"
self.assertTrue(a != b)
def test_repr(self):
a = ProtocolVersion.create(1, 0)
self.assertEqual("1.0", "{0}".format(a))
def _test_create(self, major, minor):
protocol_version = ProtocolVersion.create(major, minor)
if major is None:
expected = ProtocolVersion.ProtocolVersionMajor()
else:
expected = ProtocolVersion.ProtocolVersionMajor(major)
self.assertEqual(expected, protocol_version.protocol_version_major)
if minor is None:
expected = ProtocolVersion.ProtocolVersionMinor()
else:
expected = ProtocolVersion.ProtocolVersionMinor(minor)
self.assertEqual(expected, protocol_version.protocol_version_minor)
def test_create_with_none(self):
self._test_create(None, None)
def test_create_with_args(self):
self._test_create(1, 0)

View File

@ -0,0 +1,277 @@
# Copyright (c) 2014 The Johns Hopkins University/Applied Physics Laboratory
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from six.moves import xrange
from testtools import TestCase
from kmip.core import utils
from kmip.core.messages.contents import ProtocolVersion
from kmip.core.messages.payloads import discover_versions
class TestDiscoverVersionsRequestPayload(TestCase):
def setUp(self):
super(TestDiscoverVersionsRequestPayload, self).setUp()
self.protocol_versions_empty = list()
self.protocol_versions_one = list()
self.protocol_versions_one.append(ProtocolVersion.create(1, 0))
self.protocol_versions_two = list()
self.protocol_versions_two.append(ProtocolVersion.create(1, 1))
self.protocol_versions_two.append(ProtocolVersion.create(1, 0))
self.encoding_empty = utils.BytearrayStream((
b'\x42\x00\x79\x01\x00\x00\x00\x00'))
self.encoding_one = utils.BytearrayStream((
b'\x42\x00\x79\x01\x00\x00\x00\x28\x42\x00\x69\x01\x00\x00\x00\x20'
b'\x42\x00\x6A\x02\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x00'
b'\x42\x00\x6B\x02\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00'
b'\x00'))
self.encoding_two = utils.BytearrayStream((
b'\x42\x00\x79\x01\x00\x00\x00\x50\x42\x00\x69\x01\x00\x00\x00\x20'
b'\x42\x00\x6A\x02\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x00'
b'\x42\x00\x6B\x02\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x00'
b'\x42\x00\x69\x01\x00\x00\x00\x20\x42\x00\x6A\x02\x00\x00\x00\x04'
b'\x00\x00\x00\x01\x00\x00\x00\x00\x42\x00\x6B\x02\x00\x00\x00\x04'
b'\x00\x00\x00\x00\x00\x00\x00\x00'))
def tearDown(self):
super(TestDiscoverVersionsRequestPayload, self).tearDown()
def test_init_with_none(self):
discover_versions.DiscoverVersionsRequestPayload()
def test_init_with_args(self):
discover_versions.DiscoverVersionsRequestPayload(
self.protocol_versions_empty)
def test_validate_with_invalid_protocol_versions(self):
kwargs = {'protocol_versions': 'invalid'}
self.assertRaisesRegexp(
TypeError, "invalid protocol versions list",
discover_versions.DiscoverVersionsRequestPayload, **kwargs)
def test_validate_with_invalid_protocol_version(self):
kwargs = {'protocol_versions': ['invalid']}
self.assertRaisesRegexp(
TypeError, "invalid protocol version",
discover_versions.DiscoverVersionsRequestPayload, **kwargs)
def _test_read(self, stream, payload, protocol_versions):
payload.read(stream)
expected = len(protocol_versions)
observed = len(payload.protocol_versions)
msg = "protocol versions list decoding mismatch"
msg += "; expected {0} results, received {1}".format(
expected, observed)
self.assertEqual(expected, observed, msg)
for i in xrange(len(protocol_versions)):
expected = protocol_versions[i]
observed = payload.protocol_versions[i]
msg = "protocol version decoding mismatch"
msg += "; expected {0}, received {1}".format(expected, observed)
self.assertEqual(expected, observed, msg)
def test_read_with_empty_protocol_list(self):
stream = self.encoding_empty
payload = discover_versions.DiscoverVersionsRequestPayload()
protocol_versions = self.protocol_versions_empty
self._test_read(stream, payload, protocol_versions)
def test_read_with_one_protocol_version(self):
stream = self.encoding_one
payload = discover_versions.DiscoverVersionsRequestPayload()
protocol_versions = self.protocol_versions_one
self._test_read(stream, payload, protocol_versions)
def test_read_with_two_protocol_versions(self):
stream = self.encoding_two
payload = discover_versions.DiscoverVersionsRequestPayload()
protocol_versions = self.protocol_versions_two
self._test_read(stream, payload, protocol_versions)
def _test_write(self, payload, expected):
stream = utils.BytearrayStream()
payload.write(stream)
length_expected = len(expected)
length_received = len(stream)
msg = "encoding lengths not equal"
msg += "; expected {0}, received {1}".format(
length_expected, length_received)
self.assertEqual(length_expected, length_received, msg)
msg = "encoding mismatch"
msg += ";\nexpected:\n{0}\nreceived:\n{1}".format(expected, stream)
self.assertEqual(expected, stream, msg)
def test_write_with_empty_protocol_list(self):
payload = discover_versions.DiscoverVersionsRequestPayload(
self.protocol_versions_empty)
expected = self.encoding_empty
self._test_write(payload, expected)
def test_write_with_one_protocol_version(self):
payload = discover_versions.DiscoverVersionsRequestPayload(
self.protocol_versions_one)
expected = self.encoding_one
self._test_write(payload, expected)
def test_write_with_two_protocol_versions(self):
payload = discover_versions.DiscoverVersionsRequestPayload(
self.protocol_versions_two)
expected = self.encoding_two
self._test_write(payload, expected)
class TestDiscoverVersionsResponsePayload(TestCase):
def setUp(self):
super(TestDiscoverVersionsResponsePayload, self).setUp()
self.protocol_versions_empty = list()
self.protocol_versions_one = list()
self.protocol_versions_one.append(ProtocolVersion.create(1, 0))
self.protocol_versions_two = list()
self.protocol_versions_two.append(ProtocolVersion.create(1, 1))
self.protocol_versions_two.append(ProtocolVersion.create(1, 0))
self.encoding_empty = utils.BytearrayStream((
b'\x42\x00\x7C\x01\x00\x00\x00\x00'))
self.encoding_one = utils.BytearrayStream((
b'\x42\x00\x7C\x01\x00\x00\x00\x28\x42\x00\x69\x01\x00\x00\x00\x20'
b'\x42\x00\x6A\x02\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x00'
b'\x42\x00\x6B\x02\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00'
b'\x00'))
self.encoding_two = utils.BytearrayStream((
b'\x42\x00\x7C\x01\x00\x00\x00\x50\x42\x00\x69\x01\x00\x00\x00\x20'
b'\x42\x00\x6A\x02\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x00'
b'\x42\x00\x6B\x02\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x00'
b'\x42\x00\x69\x01\x00\x00\x00\x20\x42\x00\x6A\x02\x00\x00\x00\x04'
b'\x00\x00\x00\x01\x00\x00\x00\x00\x42\x00\x6B\x02\x00\x00\x00\x04'
b'\x00\x00\x00\x00\x00\x00\x00\x00'))
def tearDown(self):
super(TestDiscoverVersionsResponsePayload, self).tearDown()
def test_init_with_none(self):
discover_versions.DiscoverVersionsResponsePayload()
def test_init_with_args(self):
discover_versions.DiscoverVersionsResponsePayload(
self.protocol_versions_empty)
def test_validate_with_invalid_protocol_versions(self):
kwargs = {'protocol_versions': 'invalid'}
self.assertRaisesRegexp(
TypeError, "invalid protocol versions list",
discover_versions.DiscoverVersionsResponsePayload, **kwargs)
def test_validate_with_invalid_protocol_version(self):
kwargs = {'protocol_versions': ['invalid']}
self.assertRaisesRegexp(
TypeError, "invalid protocol version",
discover_versions.DiscoverVersionsResponsePayload, **kwargs)
def _test_read(self, stream, payload, protocol_versions):
payload.read(stream)
expected = len(protocol_versions)
observed = len(payload.protocol_versions)
msg = "protocol versions list decoding mismatch"
msg += "; expected {0} results, received {1}".format(
expected, observed)
self.assertEqual(expected, observed, msg)
for i in xrange(len(protocol_versions)):
expected = protocol_versions[i]
observed = payload.protocol_versions[i]
msg = "protocol version decoding mismatch"
msg += "; expected {0}, received {1}".format(expected, observed)
self.assertEqual(expected, observed, msg)
def test_read_with_empty_protocol_list(self):
stream = self.encoding_empty
payload = discover_versions.DiscoverVersionsResponsePayload()
protocol_versions = self.protocol_versions_empty
self._test_read(stream, payload, protocol_versions)
def test_read_with_one_protocol_version(self):
stream = self.encoding_one
payload = discover_versions.DiscoverVersionsResponsePayload()
protocol_versions = self.protocol_versions_one
self._test_read(stream, payload, protocol_versions)
def test_read_with_two_protocol_versions(self):
stream = self.encoding_two
payload = discover_versions.DiscoverVersionsResponsePayload()
protocol_versions = self.protocol_versions_two
self._test_read(stream, payload, protocol_versions)
def _test_write(self, payload, expected):
stream = utils.BytearrayStream()
payload.write(stream)
length_expected = len(expected)
length_received = len(stream)
msg = "encoding lengths not equal"
msg += "; expected {0}, received {1}".format(
length_expected, length_received)
self.assertEqual(length_expected, length_received, msg)
msg = "encoding mismatch"
msg += ";\nexpected:\n{0}\nreceived:\n{1}".format(expected, stream)
self.assertEqual(expected, stream, msg)
def test_write_with_empty_protocol_list(self):
payload = discover_versions.DiscoverVersionsResponsePayload(
self.protocol_versions_empty)
expected = self.encoding_empty
self._test_write(payload, expected)
def test_write_with_one_protocol_version(self):
payload = discover_versions.DiscoverVersionsResponsePayload(
self.protocol_versions_one)
expected = self.encoding_one
self._test_write(payload, expected)
def test_write_with_two_protocol_versions(self):
payload = discover_versions.DiscoverVersionsResponsePayload(
self.protocol_versions_two)
expected = self.encoding_two
self._test_write(payload, expected)

View File

@ -43,9 +43,14 @@ from kmip.core.factories.secrets import SecretFactory
from kmip.core.messages.messages import RequestBatchItem
from kmip.core.messages.messages import ResponseBatchItem
from kmip.core.messages.messages import ResponseMessage
from kmip.core.messages.contents import Operation
from kmip.core.messages.contents import ProtocolVersion
from kmip.core.messages.payloads.create_key_pair import \
CreateKeyPairRequestPayload, CreateKeyPairResponsePayload
from kmip.core.messages.payloads.discover_versions import \
DiscoverVersionsRequestPayload, DiscoverVersionsResponsePayload
from kmip.core.messages.payloads.rekey_key_pair import \
RekeyKeyPairRequestPayload, RekeyKeyPairResponsePayload
@ -60,7 +65,9 @@ from kmip.core.objects import TemplateAttribute
from kmip.core.secrets import SymmetricKey
from kmip.services.kmip_client import KMIPProxy
from kmip.services.results import CreateKeyPairResult
from kmip.services.results import DiscoverVersionsResult
from kmip.services.results import RekeyKeyPairResult
import kmip.core.utils as utils
@ -587,6 +594,45 @@ class TestKMIPClient(TestCase):
self._test_build_rekey_key_pair_batch_item(
None, None, None, None, None)
def _test_build_discover_versions_batch_item(self, protocol_versions):
batch_item = self.client._build_discover_versions_batch_item(
protocol_versions)
base = "expected {0}, received {1}"
msg = base.format(RequestBatchItem, batch_item)
self.assertIsInstance(batch_item, RequestBatchItem, msg)
operation = batch_item.operation
msg = base.format(Operation, operation)
self.assertIsInstance(operation, Operation, msg)
operation_enum = operation.enum
msg = base.format(OperationEnum.DISCOVER_VERSIONS, operation_enum)
self.assertEqual(OperationEnum.DISCOVER_VERSIONS, operation_enum, msg)
payload = batch_item.request_payload
if protocol_versions is None:
protocol_versions = list()
msg = base.format(DiscoverVersionsRequestPayload, payload)
self.assertIsInstance(payload, DiscoverVersionsRequestPayload, msg)
observed = payload.protocol_versions
msg = base.format(protocol_versions, observed)
self.assertEqual(protocol_versions, observed, msg)
def test_build_discover_versions_batch_item_with_input(self):
protocol_versions = [ProtocolVersion.create(1, 0)]
self._test_build_discover_versions_batch_item(protocol_versions)
def test_build_discover_versions_batch_item_no_input(self):
protocol_versions = None
self._test_build_discover_versions_batch_item(protocol_versions)
def test_process_batch_items(self):
batch_item = ResponseBatchItem(
operation=Operation(OperationEnum.CREATE_KEY_PAIR),
@ -651,3 +697,29 @@ class TestKMIPClient(TestCase):
msg = "expected {0}, received {1}".format(RekeyKeyPairResult, result)
self.assertIsInstance(result, RekeyKeyPairResult, msg)
def _test_process_discover_versions_batch_item(self, protocol_versions):
batch_item = ResponseBatchItem(
operation=Operation(OperationEnum.DISCOVER_VERSIONS),
response_payload=DiscoverVersionsResponsePayload(
protocol_versions))
result = self.client._process_discover_versions_batch_item(batch_item)
base = "expected {0}, received {1}"
msg = base.format(DiscoverVersionsResult, result)
self.assertIsInstance(result, DiscoverVersionsResult, msg)
# The payload maps protocol_versions to an empty list on None
if protocol_versions is None:
protocol_versions = list()
msg = base.format(protocol_versions, result.protocol_versions)
self.assertEqual(protocol_versions, result.protocol_versions, msg)
def test_process_discover_versions_batch_item_with_results(self):
protocol_versions = [ProtocolVersion.create(1, 0)]
self._test_process_discover_versions_batch_item(protocol_versions)
def test_process_discover_versions_batch_item_no_results(self):
protocol_versions = None
self._test_process_discover_versions_batch_item(protocol_versions)