diff --git a/kmip/core/attributes.py b/kmip/core/attributes.py index 7ecb029..0f87136 100644 --- a/kmip/core/attributes.py +++ b/kmip/core/attributes.py @@ -31,9 +31,22 @@ from enum import Enum # 3.1 class UniqueIdentifier(TextString): + def __init__(self, value=None, tag=Tags.UNIQUE_IDENTIFIER): + super(UniqueIdentifier, self).__init__(value, tag) + + +class PrivateKeyUniqueIdentifier(UniqueIdentifier): + def __init__(self, value=None): - super(self.__class__, self).__init__(value, - Tags.UNIQUE_IDENTIFIER) + super(PrivateKeyUniqueIdentifier, self).__init__( + value, Tags.PRIVATE_KEY_UNIQUE_IDENTIFIER) + + +class PublicKeyUniqueIdentifier(UniqueIdentifier): + + def __init__(self, value=None): + super(PublicKeyUniqueIdentifier, self).__init__( + value, Tags.PUBLIC_KEY_UNIQUE_IDENTIFIER) # 3.2 diff --git a/kmip/core/factories/payloads/request.py b/kmip/core/factories/payloads/request.py index 704f421..a00e473 100644 --- a/kmip/core/factories/payloads/request.py +++ b/kmip/core/factories/payloads/request.py @@ -16,6 +16,7 @@ from kmip.core.factories.payloads import PayloadFactory from kmip.core.messages.payloads import create +from kmip.core.messages.payloads import create_key_pair from kmip.core.messages.payloads import destroy from kmip.core.messages.payloads import get from kmip.core.messages.payloads import locate @@ -27,6 +28,9 @@ class RequestPayloadFactory(PayloadFactory): def _create_create_payload(self): return create.CreateRequestPayload() + def _create_create_key_pair_payload(self): + return create_key_pair.CreateKeyPairRequestPayload() + def _create_register_payload(self): return register.RegisterRequestPayload() diff --git a/kmip/core/factories/payloads/response.py b/kmip/core/factories/payloads/response.py index d9d9974..7d86908 100644 --- a/kmip/core/factories/payloads/response.py +++ b/kmip/core/factories/payloads/response.py @@ -16,6 +16,7 @@ from kmip.core.factories.payloads import PayloadFactory from kmip.core.messages.payloads import create +from kmip.core.messages.payloads import create_key_pair from kmip.core.messages.payloads import destroy from kmip.core.messages.payloads import get from kmip.core.messages.payloads import locate @@ -27,6 +28,9 @@ class ResponsePayloadFactory(PayloadFactory): def _create_create_payload(self): return create.CreateResponsePayload() + def _create_create_key_pair_payload(self): + return create_key_pair.CreateKeyPairResponsePayload() + def _create_register_payload(self): return register.RegisterResponsePayload() diff --git a/kmip/core/messages/payloads/create_key_pair.py b/kmip/core/messages/payloads/create_key_pair.py new file mode 100644 index 0000000..d03c9f8 --- /dev/null +++ b/kmip/core/messages/payloads/create_key_pair.py @@ -0,0 +1,205 @@ +# Copyright (c) 2014 The Johns Hopkins University/Applied Physics Laboratory +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from kmip.core import attributes +from kmip.core import objects + +from kmip.core.enums import Tags + +from kmip.core.primitives import Struct + +from kmip.core.utils import BytearrayStream + + +class CreateKeyPairRequestPayload(Struct): + + def __init__(self, + common_template_attribute=None, + private_key_template_attribute=None, + public_key_template_attribute=None): + super(self.__class__, self).__init__(Tags.REQUEST_PAYLOAD) + + self.common_template_attribute = common_template_attribute + self.private_key_template_attribute = private_key_template_attribute + self.public_key_template_attribute = public_key_template_attribute + + self.validate() + + def read(self, istream): + super(self.__class__, self).read(istream) + tstream = BytearrayStream(istream.read(self.length)) + + if self.is_tag_next(Tags.COMMON_TEMPLATE_ATTRIBUTE, tstream): + self.common_template_attribute = objects.CommonTemplateAttribute() + self.common_template_attribute.read(tstream) + + if self.is_tag_next(Tags.PRIVATE_KEY_TEMPLATE_ATTRIBUTE, tstream): + self.private_key_template_attribute = \ + objects.PrivateKeyTemplateAttribute() + self.private_key_template_attribute.read(tstream) + + if self.is_tag_next(Tags.PUBLIC_KEY_TEMPLATE_ATTRIBUTE, tstream): + self.public_key_template_attribute = \ + objects.PublicKeyTemplateAttribute() + self.public_key_template_attribute.read(tstream) + + self.is_oversized(tstream) + self.validate() + + def write(self, ostream): + tstream = BytearrayStream() + + if self.common_template_attribute is not None: + self.common_template_attribute.write(tstream) + + if self.private_key_template_attribute is not None: + self.private_key_template_attribute.write(tstream) + + if self.public_key_template_attribute is not None: + self.public_key_template_attribute.write(tstream) + + self.length = tstream.length() + super(self.__class__, self).write(ostream) + ostream.write(tstream.buffer) + + def validate(self): + self.__validate() + + def __validate(self): + if self.common_template_attribute is not None: + if not isinstance(self.common_template_attribute, + objects.CommonTemplateAttribute): + msg = "invalid common template attribute" + msg += "; expected {0}, received {1}".format( + objects.CommonTemplateAttribute, + self.common_template_attribute) + raise TypeError(msg) + + if self.private_key_template_attribute is not None: + if not isinstance(self.private_key_template_attribute, + objects.PrivateKeyTemplateAttribute): + msg = "invalid private key template attribute" + msg += "; expected {0}, received {1}".format( + objects.PrivateKeyTemplateAttribute, + self.private_key_template_attribute) + raise TypeError(msg) + + if self.public_key_template_attribute is not None: + if not isinstance(self.public_key_template_attribute, + objects.PublicKeyTemplateAttribute): + msg = "invalid public key template attribute" + msg += "; expected {0}, received {1}".format( + objects.PublicKeyTemplateAttribute, + self.public_key_template_attribute) + raise TypeError(msg) + + +class CreateKeyPairResponsePayload(Struct): + + def __init__(self, + private_key_uuid=None, + public_key_uuid=None, + private_key_template_attribute=None, + public_key_template_attribute=None): + super(self.__class__, self).__init__(Tags.RESPONSE_PAYLOAD) + + if private_key_uuid is None: + self.private_key_uuid = attributes.PrivateKeyUniqueIdentifier('') + else: + self.private_key_uuid = private_key_uuid + + if public_key_uuid is None: + self.public_key_uuid = attributes.PublicKeyUniqueIdentifier('') + else: + self.public_key_uuid = public_key_uuid + + self.private_key_template_attribute = private_key_template_attribute + self.public_key_template_attribute = public_key_template_attribute + + self.validate() + + def read(self, istream): + super(self.__class__, self).read(istream) + tstream = BytearrayStream(istream.read(self.length)) + + self.private_key_uuid.read(tstream) + self.public_key_uuid.read(tstream) + + if self.is_tag_next(Tags.PRIVATE_KEY_TEMPLATE_ATTRIBUTE, tstream): + self.private_key_template_attribute = \ + objects.PrivateKeyTemplateAttribute() + self.private_key_template_attribute.read(tstream) + + if self.is_tag_next(Tags.PUBLIC_KEY_TEMPLATE_ATTRIBUTE, tstream): + self.public_key_template_attribute = \ + objects.PublicKeyTemplateAttribute() + self.public_key_template_attribute.read(tstream) + + self.is_oversized(tstream) + self.validate() + + def write(self, ostream): + tstream = BytearrayStream() + + self.private_key_uuid.write(tstream) + self.public_key_uuid.write(tstream) + + if self.private_key_template_attribute is not None: + self.private_key_template_attribute.write(tstream) + + if self.public_key_template_attribute is not None: + self.public_key_template_attribute.write(tstream) + + self.length = tstream.length() + super(self.__class__, self).write(ostream) + ostream.write(tstream.buffer) + + def validate(self): + self.__validate() + + def __validate(self): + if not isinstance(self.private_key_uuid, + attributes.PrivateKeyUniqueIdentifier): + msg = "invalid private key unique identifier" + msg += "; expected {0}, received {1}".format( + attributes.PrivateKeyUniqueIdentifier, + self.private_key_uuid) + raise TypeError(msg) + + if not isinstance(self.public_key_uuid, + attributes.PublicKeyUniqueIdentifier): + msg = "invalid public key unique identifier" + msg += "; expected {0}, received {1}".format( + attributes.PublicKeyUniqueIdentifier, + self.public_key_uuid) + raise TypeError(msg) + + if self.private_key_template_attribute is not None: + if not isinstance(self.private_key_template_attribute, + objects.PrivateKeyTemplateAttribute): + msg = "invalid private key template attribute" + msg += "; expected {0}, received {1}".format( + objects.PrivateKeyTemplateAttribute, + self.private_key_template_attribute) + raise TypeError(msg) + + if self.public_key_template_attribute is not None: + if not isinstance(self.public_key_template_attribute, + objects.PublicKeyTemplateAttribute): + msg = "invalid public key template attribute" + msg += "; expected {0}, received {1}".format( + objects.PublicKeyTemplateAttribute, + self.public_key_template_attribute) + raise TypeError(msg) diff --git a/kmip/core/objects.py b/kmip/core/objects.py index 662b34c..5d270c6 100644 --- a/kmip/core/objects.py +++ b/kmip/core/objects.py @@ -13,6 +13,8 @@ # License for the specific language governing permissions and limitations # under the License. +from six.moves import xrange + from kmip.core import attributes from kmip.core.attributes import CryptographicParameters @@ -109,6 +111,22 @@ class Attribute(Struct): super(self.__class__, self).write(ostream) ostream.write(tstream.buffer) + def __eq__(self, other): + if isinstance(other, Attribute): + if self.attribute_name != other.attribute_name: + return False + elif self.attribute_index != other.attribute_index: + return False + elif self.attribute_value != other.attribute_value: + return False + else: + return True + else: + return NotImplemented + + def __ne__(self, other): + return not self.__eq__(other) + # 2.1.2 class Credential(Struct): @@ -793,14 +811,24 @@ class TemplateAttribute(Struct): def __init__(self, names=None, - attributes=None): - super(self.__class__, self).__init__(tag=Tags.TEMPLATE_ATTRIBUTE) - self.names = names - self.attributes = attributes + attributes=None, + tag=Tags.TEMPLATE_ATTRIBUTE): + super(TemplateAttribute, self).__init__(tag) + + if names is None: + self.names = list() + else: + self.names = names + + if attributes is None: + self.attributes = list() + else: + self.attributes = attributes + self.validate() def read(self, istream): - super(self.__class__, self).read(istream) + super(TemplateAttribute, self).read(istream) tstream = BytearrayStream(istream.read(self.length)) self.names = list() @@ -825,16 +853,14 @@ class TemplateAttribute(Struct): tstream = BytearrayStream() # Write the names and attributes of the template attribute - if self.names is not None: - for name in self.names: - name.write(tstream) - if self.attributes is not None: - for attribute in self.attributes: - attribute.write(tstream) + for name in self.names: + name.write(tstream) + for attribute in self.attributes: + attribute.write(tstream) # Write the length and value of the template attribute self.length = tstream.length() - super(self.__class__, self).write(ostream) + super(TemplateAttribute, self).write(ostream) ostream.write(tstream.buffer) def validate(self): @@ -843,3 +869,55 @@ class TemplateAttribute(Struct): def __validate(self): # TODO (peter-hamilton) Finish implementation. pass + + def __eq__(self, other): + if isinstance(other, TemplateAttribute): + if len(self.names) != len(other.names): + return False + if len(self.attributes) != len(other.attributes): + return False + + for i in xrange(len(self.names)): + a = self.names[i] + b = other.names[i] + + if a != b: + return False + + for i in xrange(len(self.attributes)): + a = self.attributes[i] + b = other.attributes[i] + + if a != b: + return False + + return True + else: + return NotImplemented + + +class CommonTemplateAttribute(TemplateAttribute): + + def __init__(self, + names=None, + attributes=None): + super(CommonTemplateAttribute, self).__init__( + names, attributes, Tags.COMMON_TEMPLATE_ATTRIBUTE) + + +class PrivateKeyTemplateAttribute(TemplateAttribute): + + def __init__(self, + names=None, + attributes=None): + super(PrivateKeyTemplateAttribute, self).__init__( + names, attributes, Tags.PRIVATE_KEY_TEMPLATE_ATTRIBUTE) + + +class PublicKeyTemplateAttribute(TemplateAttribute): + + def __init__(self, + names=None, + attributes=None): + super(PublicKeyTemplateAttribute, self).__init__( + names, attributes, Tags.PUBLIC_KEY_TEMPLATE_ATTRIBUTE) diff --git a/kmip/core/primitives.py b/kmip/core/primitives.py index 221df47..5901155 100644 --- a/kmip/core/primitives.py +++ b/kmip/core/primitives.py @@ -541,6 +541,15 @@ class TextString(Base): def __repr__(self): return '' % (self.value) + def __eq__(self, other): + if isinstance(other, TextString): + return self.value == other.value + else: + return NotImplemented + + def __ne__(self, other): + return not self.__eq__(other) + class ByteString(Base): PADDING_SIZE = 8 diff --git a/kmip/core/server.py b/kmip/core/server.py index e45d4e2..1b92c72 100644 --- a/kmip/core/server.py +++ b/kmip/core/server.py @@ -56,23 +56,28 @@ class KMIP(object): pass def create(self, object_type, template_attribute, credential=None): - raise NotImplementedError + raise NotImplementedError() + + def create_key_pair(self, common_template_attribute, + private_key_template_attribute, + public_key_template_attribute): + raise NotImplementedError() def register(self, object_type, template_attribute, secret, credential=None): - raise NotImplementedError + raise NotImplementedError() def get(self, uuid=None, key_format_type=None, key_compression_type=None, key_wrapping_specification=None, credential=None): - raise NotImplementedError + raise NotImplementedError() def destroy(self, uuid, credential=None): - raise NotImplementedError + raise NotImplementedError() def locate(self, maximum_items=None, storate_status_mask=None, object_group_member=None, attributes=None, credential=None): - raise NotImplementedError + raise NotImplementedError() class KMIPImpl(KMIP): @@ -129,11 +134,15 @@ class KMIPImpl(KMIP): s_uuid, uuid_attribute = self._save(key, attributes) ret_attributes.append(uuid_attribute) template_attribute = TemplateAttribute(attributes=ret_attributes) - return CreateResult(ResultStatus(RS.SUCCESS), - object_type=object_type, + return CreateResult(ResultStatus(RS.SUCCESS), object_type=object_type, uuid=UniqueIdentifier(s_uuid), template_attribute=template_attribute) + def create_key_pair(self, common_template_attribute, + private_key_template_attribute, + public_key_template_attribute): + raise NotImplementedError() + def register(self, object_type, template_attribute, secret, credential=None): self.logger.debug('register() called') @@ -229,10 +238,8 @@ class KMIPImpl(KMIP): # currently only symmetric keys are supported, fix this in future object_type = ObjectType(OT.SYMMETRIC_KEY) ret_value = RS.SUCCESS - return GetResult(ResultStatus(ret_value), - object_type=object_type, - uuid=uuid, - secret=managed_object) + return GetResult(ResultStatus(ret_value), object_type=object_type, + uuid=uuid, secret=managed_object) def destroy(self, uuid): self.logger.debug('destroy() called') @@ -268,8 +275,7 @@ class KMIPImpl(KMIP): msg = ResultMessage('Locate Operation Not Supported') reason = ResultReason(ResultReasonEnum.OPERATION_NOT_SUPPORTED) return LocateResult(ResultStatus(RS.OPERATION_FAILED), - result_reason=reason, - result_message=msg) + result_reason=reason, result_message=msg) def _validate_req_field(self, attrs, name, expected, msg, required=True): self.logger.debug('Validating attribute %s' % name) diff --git a/kmip/core/utils.py b/kmip/core/utils.py index 79c1230..adb4ca2 100644 --- a/kmip/core/utils.py +++ b/kmip/core/utils.py @@ -107,3 +107,16 @@ class BytearrayStream(io.RawIOBase): def length(self): return len(self.buffer) + + def __str__(self): + sbuffer = bytes(self.buffer[0:]) + return str(hexlify(sbuffer)) + + def __len__(self): + return len(self.buffer) + + def __eq__(self, other): + if isinstance(other, BytearrayStream): + return (self.buffer == other.buffer) + else: + return NotImplemented diff --git a/kmip/services/kmip_client.py b/kmip/services/kmip_client.py index a811158..4283169 100644 --- a/kmip/services/kmip_client.py +++ b/kmip/services/kmip_client.py @@ -14,6 +14,7 @@ # under the License. from kmip.services.results import CreateResult +from kmip.services.results import CreateKeyPairResult from kmip.services.results import GetResult from kmip.services.results import DestroyResult from kmip.services.results import RegisterResult @@ -37,6 +38,7 @@ from kmip.core.messages.contents import Operation from kmip.core.messages import messages from kmip.core.messages.payloads import create +from kmip.core.messages.payloads import create_key_pair from kmip.core.messages.payloads import get from kmip.core.messages.payloads import register from kmip.core.messages.payloads import locate @@ -75,6 +77,9 @@ class KMIPProxy(KMIP): do_handshake_on_connect, suppress_ragged_eofs, username, password) + self.batch_items = [] + + def open(self): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.socket = ssl.wrap_socket( sock, @@ -87,7 +92,6 @@ class KMIPProxy(KMIP): suppress_ragged_eofs=self.suppress_ragged_eofs) self.protocol = KMIPProtocol(self.socket) - def open(self): self.socket.connect((self.host, self.port)) def close(self): @@ -99,6 +103,21 @@ class KMIPProxy(KMIP): template_attribute=template_attribute, credential=credential) + def create_key_pair(self, batch=False, common_template_attribute=None, + private_key_template_attribute=None, + public_key_template_attribute=None, credential=None): + batch_item = self._build_create_key_pair_batch_item( + common_template_attribute, private_key_template_attribute, + public_key_template_attribute) + + if batch: + self.batch_items.append(batch_item) + else: + request = self._build_request_message(credential, [batch_item]) + response = self._send_and_receive_message(request) + results = self._process_batch_items(response) + return results[0] + def get(self, uuid=None, key_format_type=None, key_compression_type=None, key_wrapping_specification=None, credential=None): return self._get(unique_identifier=uuid, credential=credential) @@ -163,6 +182,58 @@ class KMIPProxy(KMIP): payload_template_attribute) return result + def _build_create_key_pair_batch_item(self, common_template_attribute=None, + private_key_template_attribute=None, + public_key_template_attribute=None): + operation = Operation(OperationEnum.CREATE_KEY_PAIR) + payload = create_key_pair.CreateKeyPairRequestPayload( + common_template_attribute=common_template_attribute, + private_key_template_attribute=private_key_template_attribute, + public_key_template_attribute=public_key_template_attribute) + batch_item = messages.RequestBatchItem( + operation=operation, request_payload=payload) + return batch_item + + def _process_batch_items(self, response): + results = [] + for batch_item in response.batch_items: + operation = batch_item.operation.enum + processor = self._get_batch_item_processor(operation) + result = processor(batch_item) + results.append(result) + return results + + def _get_batch_item_processor(self, operation): + if operation == OperationEnum.CREATE_KEY_PAIR: + return self._process_create_key_pair_batch_item + else: + raise ValueError("no processor for given operation") + + def _process_create_key_pair_batch_item(self, batch_item): + payload = batch_item.response_payload + + payload_private_key_uuid = None + payload_public_key_uuid = None + payload_private_key_template_attribute = None + payload_public_key_template_attribute = None + + if payload is not None: + payload_private_key_uuid = payload.private_key_uuid + payload_public_key_uuid = payload.public_key_uuid + payload_private_key_template_attribute = \ + payload.private_key_template_attribute + payload_public_key_template_attribute = \ + payload.public_key_template_attribute + + result = CreateKeyPairResult(batch_item.result_status, + batch_item.result_reason, + batch_item.result_message, + payload_private_key_uuid, + payload_public_key_uuid, + payload_private_key_template_attribute, + payload_public_key_template_attribute) + return result + def _get(self, unique_identifier=None, key_format_type=None, @@ -382,6 +453,13 @@ class KMIPProxy(KMIP): def _receive_message(self): return self.protocol.read() + def _send_and_receive_message(self, request): + self._send_message(request) + response = messages.ResponseMessage() + data = self._receive_message() + response.read(data) + return response + def _set_variables(self, host, port, keyfile, certfile, cert_reqs, ssl_version, ca_certs, do_handshake_on_connect, suppress_ragged_eofs, diff --git a/kmip/services/results.py b/kmip/services/results.py index e77d836..a0a9a76 100644 --- a/kmip/services/results.py +++ b/kmip/services/results.py @@ -61,6 +61,22 @@ class CreateResult(OperationResult): self.template_attribute = None +class CreateKeyPairResult(OperationResult): + + def __init__(self, + result_status, + result_reason=None, + result_message=None, + private_key_uuid=None, + public_key_uuid=None, + private_key_template_attribute=None, + public_key_template_attribute=None): + self.private_key_uuid = private_key_uuid + self.public_key_uuid = public_key_uuid + self.private_key_template_attribute = private_key_template_attribute + self.public_key_template_attribute = public_key_template_attribute + + class RegisterResult(OperationResult): def __init__(self, diff --git a/kmip/tests/core/factories/payloads/test_request.py b/kmip/tests/core/factories/payloads/test_request.py index 9872c6f..05b002c 100644 --- a/kmip/tests/core/factories/payloads/test_request.py +++ b/kmip/tests/core/factories/payloads/test_request.py @@ -19,20 +19,21 @@ from kmip.core.enums import Operation from kmip.core.factories.payloads.request import RequestPayloadFactory from kmip.core.messages.payloads import create +from kmip.core.messages.payloads import create_key_pair from kmip.core.messages.payloads import destroy from kmip.core.messages.payloads import get from kmip.core.messages.payloads import locate from kmip.core.messages.payloads import register -class TestPayloadFactory(testtools.TestCase): +class TestRequestPayloadFactory(testtools.TestCase): def setUp(self): - super(TestPayloadFactory, self).setUp() + super(TestRequestPayloadFactory, self).setUp() self.factory = RequestPayloadFactory() def tearDown(self): - super(TestPayloadFactory, self).tearDown() + super(TestRequestPayloadFactory, self).tearDown() def _test_not_implemented(self, func, args): self.assertRaises(NotImplementedError, func, args) @@ -46,8 +47,9 @@ class TestPayloadFactory(testtools.TestCase): self._test_payload_type(payload, create.CreateRequestPayload) def test_create_create_key_pair_payload(self): - self._test_not_implemented( - self.factory.create, Operation.CREATE_KEY_PAIR) + payload = self.factory.create(Operation.CREATE_KEY_PAIR) + self._test_payload_type( + payload, create_key_pair.CreateKeyPairRequestPayload) def test_create_register_payload(self): payload = self.factory.create(Operation.REGISTER) diff --git a/kmip/tests/core/factories/payloads/test_response.py b/kmip/tests/core/factories/payloads/test_response.py index 6a526fb..cf33641 100644 --- a/kmip/tests/core/factories/payloads/test_response.py +++ b/kmip/tests/core/factories/payloads/test_response.py @@ -19,20 +19,21 @@ from kmip.core.enums import Operation from kmip.core.factories.payloads.response import ResponsePayloadFactory from kmip.core.messages.payloads import create +from kmip.core.messages.payloads import create_key_pair from kmip.core.messages.payloads import destroy from kmip.core.messages.payloads import get from kmip.core.messages.payloads import locate from kmip.core.messages.payloads import register -class TestPayloadFactory(testtools.TestCase): +class TestResponsePayloadFactory(testtools.TestCase): def setUp(self): - super(TestPayloadFactory, self).setUp() + super(TestResponsePayloadFactory, self).setUp() self.factory = ResponsePayloadFactory() def tearDown(self): - super(TestPayloadFactory, self).tearDown() + super(TestResponsePayloadFactory, self).tearDown() def _test_not_implemented(self, func, args): self.assertRaises(NotImplementedError, func, args) @@ -46,8 +47,9 @@ class TestPayloadFactory(testtools.TestCase): self._test_payload_type(payload, create.CreateResponsePayload) def test_create_create_key_pair_payload(self): - self._test_not_implemented( - self.factory.create, Operation.CREATE_KEY_PAIR) + payload = self.factory.create(Operation.CREATE_KEY_PAIR) + self._test_payload_type( + payload, create_key_pair.CreateKeyPairResponsePayload) def test_create_register_payload(self): payload = self.factory.create(Operation.REGISTER) diff --git a/kmip/tests/core/messages/payloads/test_create_key_pair.py b/kmip/tests/core/messages/payloads/test_create_key_pair.py new file mode 100644 index 0000000..4269007 --- /dev/null +++ b/kmip/tests/core/messages/payloads/test_create_key_pair.py @@ -0,0 +1,306 @@ +# Copyright (c) 2014 The Johns Hopkins University/Applied Physics Laboratory +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from testtools import TestCase + +from kmip.core import attributes +from kmip.core import objects +from kmip.core import utils + +from kmip.core.messages.payloads import create_key_pair + + +class TestCreateKeyPairRequestPayload(TestCase): + + def setUp(self): + super(TestCreateKeyPairRequestPayload, self).setUp() + + self.common_template_attribute = objects.CommonTemplateAttribute() + self.private_key_template_attribute = \ + objects.PrivateKeyTemplateAttribute() + self.public_key_template_attribute = \ + objects.PublicKeyTemplateAttribute() + + self.encoding_empty = utils.BytearrayStream(( + b'\x42\x00\x79\x01\x00\x00\x00\x00')) + self.encoding_full = utils.BytearrayStream(( + b'\x42\x00\x79\x01\x00\x00\x00\x18\x42\x00\x1F\x01\x00\x00\x00\x00' + b'\x42\x00\x65\x01\x00\x00\x00\x00\x42\x00\x6E\x01\x00\x00\x00' + b'\x00')) + + def tearDown(self): + super(TestCreateKeyPairRequestPayload, self).tearDown() + + def test_init_with_none(self): + create_key_pair.CreateKeyPairRequestPayload() + + def test_init_with_args(self): + create_key_pair.CreateKeyPairRequestPayload( + self.common_template_attribute, + self.private_key_template_attribute, + self.public_key_template_attribute) + + def test_validate_with_invalid_common_template_attribute(self): + kwargs = {'common_template_attribute': 'invalid', + 'private_key_template_attribute': None, + 'public_key_template_attribute': None} + self.assertRaisesRegexp( + TypeError, "invalid common template attribute", + create_key_pair.CreateKeyPairRequestPayload, **kwargs) + + def test_validate_with_invalid_private_key_template_attribute(self): + kwargs = {'common_template_attribute': None, + 'private_key_template_attribute': 'invalid', + 'public_key_template_attribute': None} + self.assertRaisesRegexp( + TypeError, "invalid private key template attribute", + create_key_pair.CreateKeyPairRequestPayload, **kwargs) + + def test_validate_with_invalid_public_key_template_attribute(self): + kwargs = {'common_template_attribute': None, + 'private_key_template_attribute': None, + 'public_key_template_attribute': 'invalid'} + self.assertRaises( + TypeError, "invalid public key template attribute", + create_key_pair.CreateKeyPairRequestPayload, **kwargs) + + def _test_read(self, stream, payload, common_template_attribute, + private_key_template_attribute, + public_key_template_attribute): + payload.read(stream) + + msg = "common_template_attribute decoding mismatch" + msg += "; expected {0}, received {1}".format( + common_template_attribute, payload.common_template_attribute) + self.assertEqual(common_template_attribute, + payload.common_template_attribute, msg) + + msg = "private_key_template_attribute decoding mismatch" + msg += "; expected {0}, received {1}".format( + private_key_template_attribute, + payload.private_key_template_attribute) + self.assertEqual(private_key_template_attribute, + payload.private_key_template_attribute, msg) + + msg = "public_key_template_attribute decoding mismatch" + msg += "; expected {0}, received {1}".format( + public_key_template_attribute, + payload.public_key_template_attribute) + self.assertEqual(public_key_template_attribute, + payload.public_key_template_attribute, msg) + + def test_read_with_none(self): + stream = self.encoding_empty + payload = create_key_pair.CreateKeyPairRequestPayload() + + self._test_read(stream, payload, None, None, None) + + def test_read_with_args(self): + stream = self.encoding_full + payload = create_key_pair.CreateKeyPairRequestPayload() + + self._test_read(stream, payload, self.common_template_attribute, + self.private_key_template_attribute, + self.public_key_template_attribute) + + def _test_write(self, stream, payload, expected): + payload.write(stream) + + length_expected = len(expected) + length_received = len(stream) + + msg = "encoding lengths not equal" + msg += "; expected {0}, received {1}".format( + length_expected, length_received) + self.assertEqual(length_expected, length_received, msg) + + msg = "encoding mismatch" + msg += ";\nexpected:\n{0}\nreceived:\n{1}".format(expected, stream) + + self.assertEqual(expected, stream, msg) + + def test_write_with_none(self): + stream = utils.BytearrayStream() + payload = create_key_pair.CreateKeyPairRequestPayload() + + self._test_write(stream, payload, self.encoding_empty) + + def test_write_with_args(self): + stream = utils.BytearrayStream() + payload = create_key_pair.CreateKeyPairRequestPayload( + self.common_template_attribute, + self.private_key_template_attribute, + self.public_key_template_attribute) + + self._test_write(stream, payload, self.encoding_full) + + +class TestCreateKeyPairResponsePayload(TestCase): + + def setUp(self): + super(TestCreateKeyPairResponsePayload, self).setUp() + + self.uuid = '00000000-0000-0000-0000-000000000000' + self.private_key_uuid = attributes.PrivateKeyUniqueIdentifier( + self.uuid) + self.public_key_uuid = attributes.PublicKeyUniqueIdentifier( + self.uuid) + self.empty_private_key_uuid = attributes.PrivateKeyUniqueIdentifier('') + self.empty_public_key_uuid = attributes.PublicKeyUniqueIdentifier('') + + self.private_key_template_attribute = \ + objects.PrivateKeyTemplateAttribute() + self.public_key_template_attribute = \ + objects.PublicKeyTemplateAttribute() + + self.encoding_empty = utils.BytearrayStream(( + b'\x42\x00\x7C\x01\x00\x00\x00\x10\x42\x00\x66\x07\x00\x00\x00\x00' + b'\x42\x00\x6F\x07\x00\x00\x00\x00')) + self.encoding_full = utils.BytearrayStream(( + b'\x42\x00\x7C\x01\x00\x00\x00\x70\x42\x00\x66\x07\x00\x00\x00\x24' + b'\x30\x30\x30\x30\x30\x30\x30\x30\x2d\x30\x30\x30\x30\x2d\x30\x30' + b'\x30\x30\x2d\x30\x30\x30\x30\x2d\x30\x30\x30\x30\x30\x30\x30\x30' + b'\x30\x30\x30\x30\x00\x00\x00\x00\x42\x00\x6F\x07\x00\x00\x00\x24' + b'\x30\x30\x30\x30\x30\x30\x30\x30\x2d\x30\x30\x30\x30\x2d\x30\x30' + b'\x30\x30\x2d\x30\x30\x30\x30\x2d\x30\x30\x30\x30\x30\x30\x30\x30' + b'\x30\x30\x30\x30\x00\x00\x00\x00\x42\x00\x65\x01\x00\x00\x00\x00' + b'\x42\x00\x6E\x01\x00\x00\x00\x00')) + + def tearDown(self): + super(TestCreateKeyPairResponsePayload, self).tearDown() + + def test_init_with_none(self): + create_key_pair.CreateKeyPairResponsePayload() + + def test_init_with_args(self): + create_key_pair.CreateKeyPairResponsePayload( + self.private_key_uuid, self.public_key_uuid, + self.private_key_template_attribute, + self.public_key_template_attribute) + + def test_validate_with_invalid_private_key_unique_identifier(self): + kwargs = {'private_key_uuid': 'invalid', + 'public_key_uuid': self.public_key_uuid, + 'private_key_template_attribute': None, + 'public_key_template_attribute': None} + self.assertRaisesRegexp( + TypeError, "invalid private key unique identifier", + create_key_pair.CreateKeyPairResponsePayload, **kwargs) + + def test_validate_with_invalid_public_key_unique_identifier(self): + kwargs = {'private_key_uuid': self.private_key_uuid, + 'public_key_uuid': 'invalid', + 'private_key_template_attribute': None, + 'public_key_template_attribute': None} + self.assertRaisesRegexp( + TypeError, "invalid public key unique identifier", + create_key_pair.CreateKeyPairResponsePayload, **kwargs) + + def test_validate_with_invalid_private_key_template_attribute(self): + kwargs = {'private_key_uuid': self.private_key_uuid, + 'public_key_uuid': self.public_key_uuid, + 'private_key_template_attribute': 'invalid', + 'public_key_template_attribute': None} + self.assertRaisesRegexp( + TypeError, "invalid private key template attribute", + create_key_pair.CreateKeyPairResponsePayload, **kwargs) + + def test_validate_with_invalid_public_key_template_attribute(self): + kwargs = {'private_key_uuid': self.private_key_uuid, + 'public_key_uuid': self.public_key_uuid, + 'private_key_template_attribute': None, + 'public_key_template_attribute': 'invalid'} + self.assertRaisesRegexp( + TypeError, "invalid public key template attribute", + create_key_pair.CreateKeyPairResponsePayload, **kwargs) + + def _test_read(self, stream, payload, private_key_uuid, public_key_uuid, + private_key_template_attribute, + public_key_template_attribute): + payload.read(stream) + + msg = "private_key_uuid decoding mismatch" + msg += "; expected {0}, received {1}".format( + private_key_uuid, payload.private_key_uuid) + self.assertEqual(private_key_uuid, payload.private_key_uuid, msg) + + msg = "public_key_uuid decoding mismatch" + msg += "; expected {0}, received {1}".format( + public_key_uuid, payload.public_key_uuid) + self.assertEqual(public_key_uuid, payload.public_key_uuid, msg) + + msg = "private_key_template_attribute decoding mismatch" + msg += "; expected {0}, received {1}".format( + private_key_template_attribute, + payload.private_key_template_attribute) + self.assertEqual(private_key_template_attribute, + payload.private_key_template_attribute, msg) + + msg = "public_key_template_attribute decoding mismatch" + msg += "; expected {0}, received {1}".format( + public_key_template_attribute, + payload.public_key_template_attribute) + self.assertEqual(public_key_template_attribute, + payload.public_key_template_attribute, msg) + + def test_read_with_none(self): + stream = self.encoding_empty + payload = create_key_pair.CreateKeyPairResponsePayload() + + self._test_read(stream, payload, self.empty_private_key_uuid, + self.empty_public_key_uuid, None, None) + + def test_read_with_args(self): + stream = self.encoding_full + payload = create_key_pair.CreateKeyPairResponsePayload( + self.private_key_uuid, self.public_key_uuid, + self.private_key_template_attribute, + self.public_key_template_attribute) + + self._test_read(stream, payload, self.private_key_uuid, + self.public_key_uuid, + self.private_key_template_attribute, + self.public_key_template_attribute) + + def _test_write(self, stream, payload, expected): + payload.write(stream) + + length_expected = len(expected) + length_received = len(stream) + + msg = "encoding lengths not equal" + msg += "; expected {0}, received {1}".format( + length_expected, length_received) + self.assertEqual(length_expected, length_received, msg) + + msg = "encoding mismatch" + msg += ";\nexpected:\n{0}\nreceived:\n{1}".format(expected, stream) + + self.assertEqual(expected, stream, msg) + + def test_write_with_none(self): + stream = utils.BytearrayStream() + payload = create_key_pair.CreateKeyPairResponsePayload() + + self._test_write(stream, payload, self.encoding_empty) + + def test_write_with_args(self): + stream = utils.BytearrayStream() + payload = create_key_pair.CreateKeyPairResponsePayload( + self.private_key_uuid, self.public_key_uuid, + self.private_key_template_attribute, + self.public_key_template_attribute) + + self._test_write(stream, payload, self.encoding_full) diff --git a/kmip/tests/services/test_kmip_client.py b/kmip/tests/services/test_kmip_client.py index a8c9b47..68bcaad 100644 --- a/kmip/tests/services/test_kmip_client.py +++ b/kmip/tests/services/test_kmip_client.py @@ -26,6 +26,7 @@ from kmip.core.enums import CredentialType from kmip.core.enums import CryptographicAlgorithm from kmip.core.enums import CryptographicUsageMask from kmip.core.enums import ObjectType +from kmip.core.enums import Operation as OperationEnum from kmip.core.enums import KeyFormatType from kmip.core.enums import ResultStatus from kmip.core.enums import ResultReason @@ -37,17 +38,28 @@ from kmip.core.factories.attributes import AttributeFactory from kmip.core.factories.credentials import CredentialFactory from kmip.core.factories.secrets import SecretFactory +from kmip.core.messages.messages import RequestBatchItem +from kmip.core.messages.messages import ResponseBatchItem +from kmip.core.messages.messages import ResponseMessage +from kmip.core.messages.contents import Operation +from kmip.core.messages.payloads.create_key_pair import \ + CreateKeyPairRequestPayload, CreateKeyPairResponsePayload + from kmip.core.objects import Attribute +from kmip.core.objects import CommonTemplateAttribute +from kmip.core.objects import PrivateKeyTemplateAttribute +from kmip.core.objects import PublicKeyTemplateAttribute from kmip.core.objects import TemplateAttribute from kmip.core.secrets import SymmetricKey from kmip.services.kmip_client import KMIPProxy +from kmip.services.results import CreateKeyPairResult import kmip.core.utils as utils -class TestKMIPClient(TestCase): +class TestKMIPClientIntegration(TestCase): STARTUP_TIME = 1.0 SHUTDOWN_TIME = 0.1 KMIP_PORT = 9090 @@ -55,7 +67,7 @@ class TestKMIPClient(TestCase): os.path.abspath(__file__)), '../../demos/certs/server.crt')) def setUp(self): - super(TestKMIPClient, self).setUp() + super(TestKMIPClientIntegration, self).setUp() self.attr_factory = AttributeFactory() self.cred_factory = CredentialFactory() @@ -82,7 +94,7 @@ class TestKMIPClient(TestCase): raise e def tearDown(self): - super(TestKMIPClient, self).tearDown() + super(TestKMIPClientIntegration, self).tearDown() # Close the client proxy and shutdown the server self.client.close() @@ -248,66 +260,6 @@ class TestKMIPClient(TestCase): expected, observed, 'value') self.assertEqual(expected, observed, message) - # TODO (peter-hamilton) Modify for credential type and/or add new test - def test_build_credential(self): - username = 'username' - password = 'password' - cred_type = CredentialType.USERNAME_AND_PASSWORD - self.client.username = username - self.client.password = password - - credential = self.client._build_credential() - - message = utils.build_er_error(credential.__class__, 'type', - cred_type, - credential.credential_type.enum, - 'value') - self.assertEqual(CredentialType.USERNAME_AND_PASSWORD, - credential.credential_type.enum, - message) - - message = utils.build_er_error( - credential.__class__, 'type', username, - credential.credential_value.username.value, 'value') - self.assertEqual(username, credential.credential_value.username.value, - message) - - message = utils.build_er_error( - credential.__class__, 'type', password, - credential.credential_value.password.value, 'value') - self.assertEqual(password, credential.credential_value.password.value, - message) - - def test_build_credential_no_username(self): - username = None - password = 'password' - self.client.username = username - self.client.password = password - - exception = self.assertRaises(ValueError, - self.client._build_credential) - self.assertEqual('cannot build credential, username is None', - str(exception)) - - def test_build_credential_no_password(self): - username = 'username' - password = None - self.client.username = username - self.client.password = password - - exception = self.assertRaises(ValueError, - self.client._build_credential) - self.assertEqual('cannot build credential, password is None', - str(exception)) - - def test_build_credential_no_creds(self): - self.client.username = None - self.client.password = None - - credential = self.client._build_credential() - - self.assertEqual(None, credential) - def _shutdown_server(self): if self.server.poll() is not None: return @@ -445,3 +397,163 @@ class TestKMIPClient(TestCase): message = utils.build_er_error(Attribute, 'value', expected, observed, 'attribute_value') self.assertEqual(expected, observed, message) + + +class TestKMIPClient(TestCase): + + def setUp(self): + super(TestKMIPClient, self).setUp() + + self.attr_factory = AttributeFactory() + self.cred_factory = CredentialFactory() + self.secret_factory = SecretFactory() + + self.client = KMIPProxy() + + def tearDown(self): + super(TestKMIPClient, self).tearDown() + + # TODO (peter-hamilton) Modify for credential type and/or add new test + def test_build_credential(self): + username = 'username' + password = 'password' + cred_type = CredentialType.USERNAME_AND_PASSWORD + self.client.username = username + self.client.password = password + + credential = self.client._build_credential() + + message = utils.build_er_error(credential.__class__, 'type', + cred_type, + credential.credential_type.enum, + 'value') + self.assertEqual(CredentialType.USERNAME_AND_PASSWORD, + credential.credential_type.enum, + message) + + message = utils.build_er_error( + credential.__class__, 'type', username, + credential.credential_value.username.value, 'value') + self.assertEqual(username, credential.credential_value.username.value, + message) + + message = utils.build_er_error( + credential.__class__, 'type', password, + credential.credential_value.password.value, 'value') + self.assertEqual(password, credential.credential_value.password.value, + message) + + def test_build_credential_no_username(self): + username = None + password = 'password' + self.client.username = username + self.client.password = password + + exception = self.assertRaises(ValueError, + self.client._build_credential) + self.assertEqual('cannot build credential, username is None', + str(exception)) + + def test_build_credential_no_password(self): + username = 'username' + password = None + self.client.username = username + self.client.password = password + + exception = self.assertRaises(ValueError, + self.client._build_credential) + self.assertEqual('cannot build credential, password is None', + str(exception)) + + def test_build_credential_no_creds(self): + self.client.username = None + self.client.password = None + + credential = self.client._build_credential() + + self.assertEqual(None, credential) + + def _test_build_create_key_pair_batch_item(self, common, private, public): + batch_item = self.client._build_create_key_pair_batch_item( + common_template_attribute=common, + private_key_template_attribute=private, + public_key_template_attribute=public) + + base = "expected {0}, received {1}" + msg = base.format(RequestBatchItem, batch_item) + self.assertIsInstance(batch_item, RequestBatchItem, msg) + + operation = batch_item.operation + + msg = base.format(Operation, operation) + self.assertIsInstance(operation, Operation, msg) + + operation_enum = operation.enum + + msg = base.format(OperationEnum.CREATE_KEY_PAIR, operation_enum) + self.assertEqual(OperationEnum.CREATE_KEY_PAIR, operation_enum, msg) + + payload = batch_item.request_payload + + msg = base.format(CreateKeyPairRequestPayload, payload) + self.assertIsInstance(payload, CreateKeyPairRequestPayload, msg) + + common_observed = payload.common_template_attribute + private_observed = payload.private_key_template_attribute + public_observed = payload.public_key_template_attribute + + msg = base.format(common, common_observed) + self.assertEqual(common, common_observed, msg) + + msg = base.format(private, private_observed) + self.assertEqual(private, private_observed, msg) + + msg = base.format(public, public_observed) + self.assertEqual(public, public_observed) + + def test_build_create_key_pair_batch_item_with_input(self): + self._test_build_create_key_pair_batch_item( + CommonTemplateAttribute(), + PrivateKeyTemplateAttribute(), + PublicKeyTemplateAttribute()) + + def test_build_create_key_pair_batch_item_no_input(self): + self._test_build_create_key_pair_batch_item(None, None, None) + + def test_process_batch_items(self): + batch_item = ResponseBatchItem( + operation=Operation(OperationEnum.CREATE_KEY_PAIR), + response_payload=CreateKeyPairResponsePayload()) + response = ResponseMessage(batch_items=[batch_item, batch_item]) + results = self.client._process_batch_items(response) + + base = "expected {0}, received {1}" + msg = base.format(list, results) + self.assertIsInstance(results, list, msg) + + msg = "number of results " + base.format(2, len(results)) + self.assertEqual(2, len(results), msg) + + for result in results: + msg = base.format(CreateKeyPairResult, result) + self.assertIsInstance(result, CreateKeyPairResult, msg) + + def test_process_batch_items_no_batch_items(self): + response = ResponseMessage(batch_items=[]) + results = self.client._process_batch_items(response) + + base = "expected {0}, received {1}" + msg = base.format(list, results) + self.assertIsInstance(results, list, msg) + + msg = "number of results " + base.format(0, len(results)) + self.assertEqual(0, len(results), msg) + + def test_process_create_key_pair_batch_item(self): + batch_item = ResponseBatchItem( + operation=Operation(OperationEnum.CREATE_KEY_PAIR), + response_payload=CreateKeyPairResponsePayload()) + result = self.client._process_create_key_pair_batch_item(batch_item) + + msg = "expected {0}, received {1}".format(CreateKeyPairResult, result) + self.assertIsInstance(result, CreateKeyPairResult, msg)