diff --git a/kmip/core/messages/payloads/create.py b/kmip/core/messages/payloads/create.py index c4ec413..e9169f7 100644 --- a/kmip/core/messages/payloads/create.py +++ b/kmip/core/messages/payloads/create.py @@ -29,11 +29,15 @@ class CreateRequestPayload(primitives.Struct): Attributes: object_type: The type of the object to create. template_attribute: A group of attributes to set on the new object. + protection_storage_masks: An integer representing all of the + protection storage mask selections for the new object. Added in + KMIP 2.0. """ def __init__(self, object_type=None, - template_attribute=None): + template_attribute=None, + protection_storage_masks=None): """ Construct a Create request payload structure. @@ -44,6 +48,9 @@ class CreateRequestPayload(primitives.Struct): template_attribute (TemplateAttribute): A TemplateAttribute structure containing a set of attributes to set on the new object. Optional, defaults to None. Required for read/write. + protection_storage_masks (int): An integer representing all of + the protection storage mask selections for the new object. + Optional, defaults to None. Added in KMIP 2.0. """ super(CreateRequestPayload, self).__init__( tag=enums.Tags.REQUEST_PAYLOAD @@ -51,9 +58,11 @@ class CreateRequestPayload(primitives.Struct): self._object_type = None self._template_attribute = None + self._protection_storage_masks = None self.object_type = object_type self.template_attribute = template_attribute + self.protection_storage_masks = protection_storage_masks @property def object_type(self): @@ -92,6 +101,26 @@ class CreateRequestPayload(primitives.Struct): "Template attribute must be a TemplateAttribute structure." ) + @property + def protection_storage_masks(self): + if self._protection_storage_masks: + return self._protection_storage_masks.value + return None + + @protection_storage_masks.setter + def protection_storage_masks(self, value): + if value is None: + self._protection_storage_masks = None + elif isinstance(value, six.integer_types): + self._protection_storage_masks = primitives.Integer( + value=value, + tag=enums.Tags.PROTECTION_STORAGE_MASKS + ) + else: + raise TypeError( + "The protection storage masks must be an integer." + ) + def read(self, input_buffer, kmip_version=enums.KMIPVersion.KMIP_1_0): """ Read the data encoding the Create request payload and decode it into @@ -158,6 +187,19 @@ class CreateRequestPayload(primitives.Struct): "attributes structure." ) + if self.is_tag_next( + enums.Tags.PROTECTION_STORAGE_MASKS, + local_buffer + ): + protection_storage_masks = primitives.Integer( + tag=enums.Tags.PROTECTION_STORAGE_MASKS + ) + protection_storage_masks.read( + local_buffer, + kmip_version=kmip_version + ) + self._protection_storage_masks = protection_storage_masks + self.is_oversized(local_buffer) def write(self, output_buffer, kmip_version=enums.KMIPVersion.KMIP_1_0): @@ -213,6 +255,12 @@ class CreateRequestPayload(primitives.Struct): "attribute field." ) + if self._protection_storage_masks: + self._protection_storage_masks.write( + local_buffer, + kmip_version=kmip_version + ) + self.length = local_buffer.length() super(CreateRequestPayload, self).write( output_buffer, @@ -226,6 +274,9 @@ class CreateRequestPayload(primitives.Struct): return False elif self.template_attribute != other.template_attribute: return False + elif self.protection_storage_masks != \ + other.protection_storage_masks: + return False else: return True else: @@ -240,7 +291,12 @@ class CreateRequestPayload(primitives.Struct): def __repr__(self): args = ", ".join([ "object_type={}".format(self.object_type), - "template_attribute={}".format(repr(self.template_attribute)) + "template_attribute={}".format(repr(self.template_attribute)), + "protection_storage_masks={}".format( + "{}".format( + repr(self.protection_storage_masks) + ) if self._protection_storage_masks else None + ) ]) return "CreateRequestPayload({})".format(args) @@ -248,7 +304,12 @@ class CreateRequestPayload(primitives.Struct): value = ", ".join( [ '"object_type": {}'.format(self.object_type), - '"template_attribute": {}'.format(self.template_attribute) + '"template_attribute": {}'.format(self.template_attribute), + '"protection_storage_masks": {}'.format( + "{}".format( + str(self.protection_storage_masks) + ) if self._protection_storage_masks else None + ) ] ) return '{' + value + '}' diff --git a/kmip/tests/unit/core/messages/payloads/test_create.py b/kmip/tests/unit/core/messages/payloads/test_create.py index 34430f7..de25293 100644 --- a/kmip/tests/unit/core/messages/payloads/test_create.py +++ b/kmip/tests/unit/core/messages/payloads/test_create.py @@ -76,13 +76,15 @@ class TestCreateRequestPayload(testtools.TestCase): # Cryptographic Algorithm - AES # Cryptographic Length - 128 # Cryptographic Usage Mask - Encrypt | Decrypt + # Protection Storage Masks - Software | Hardware self.full_encoding_with_attributes = utils.BytearrayStream( - b'\x42\x00\x79\x01\x00\x00\x00\x48' + b'\x42\x00\x79\x01\x00\x00\x00\x58' b'\x42\x00\x57\x05\x00\x00\x00\x04\x00\x00\x00\x02\x00\x00\x00\x00' b'\x42\x01\x25\x01\x00\x00\x00\x30' b'\x42\x00\x28\x05\x00\x00\x00\x04\x00\x00\x00\x03\x00\x00\x00\x00' b'\x42\x00\x2A\x02\x00\x00\x00\x04\x00\x00\x00\x80\x00\x00\x00\x00' b'\x42\x00\x2C\x02\x00\x00\x00\x04\x00\x00\x00\x0C\x00\x00\x00\x00' + b'\x42\x01\x5F\x02\x00\x00\x00\x04\x00\x00\x00\x03\x00\x00\x00\x00' ) # Encoding obtained from the KMIP 1.1 testing document, @@ -184,14 +186,40 @@ class TestCreateRequestPayload(testtools.TestCase): *args ) + def test_invalid_protection_storage_masks(self): + """ + Test that a TypeError is raised when an invalid value is used to set + the protection storage masks of a Create request payload. + """ + kwargs = {"protection_storage_masks": "invalid"} + self.assertRaisesRegex( + TypeError, + "The protection storage masks must be an integer.", + payloads.CreateRequestPayload, + **kwargs + ) + + args = ( + payloads.CreateRequestPayload(), + "protection_storage_masks", + "invalid" + ) + self.assertRaisesRegex( + TypeError, + "The protection storage masks must be an integer.", + setattr, + *args + ) + def test_read(self): """ Test that a Create request payload can be read from a data stream. """ payload = payloads.CreateRequestPayload() - self.assertEqual(None, payload.object_type) - self.assertEqual(None, payload.template_attribute) + self.assertIsNone(payload.object_type) + self.assertIsNone(payload.template_attribute) + self.assertIsNone(payload.protection_storage_masks) payload.read(self.full_encoding) @@ -237,6 +265,7 @@ class TestCreateRequestPayload(testtools.TestCase): ), payload.template_attribute ) + self.assertIsNone(payload.protection_storage_masks) def test_read_kmip_2_0(self): """ @@ -245,8 +274,9 @@ class TestCreateRequestPayload(testtools.TestCase): """ payload = payloads.CreateRequestPayload() - self.assertEqual(None, payload.object_type) - self.assertEqual(None, payload.template_attribute) + self.assertIsNone(payload.object_type) + self.assertIsNone(payload.template_attribute) + self.assertIsNone(payload.protection_storage_masks) payload.read( self.full_encoding_with_attributes, @@ -295,6 +325,7 @@ class TestCreateRequestPayload(testtools.TestCase): ), payload.template_attribute ) + self.assertEqual(3, payload.protection_storage_masks) def test_read_missing_object_type(self): """ @@ -447,6 +478,10 @@ class TestCreateRequestPayload(testtools.TestCase): ) ) ] + ), + protection_storage_masks=( + enums.ProtectionStorageMask.SOFTWARE.value | + enums.ProtectionStorageMask.HARDWARE.value ) ) @@ -566,12 +601,17 @@ class TestCreateRequestPayload(testtools.TestCase): ) ) ] + ), + protection_storage_masks=( + enums.ProtectionStorageMask.SOFTWARE.value | + enums.ProtectionStorageMask.HARDWARE.value ) ) self.assertEqual( "CreateRequestPayload(" "object_type=ObjectType.SYMMETRIC_KEY, " - "template_attribute=Struct())", + "template_attribute=Struct(), " + "protection_storage_masks=3)", repr(payload) ) @@ -603,12 +643,17 @@ class TestCreateRequestPayload(testtools.TestCase): ) ) ] + ), + protection_storage_masks=( + enums.ProtectionStorageMask.SOFTWARE.value | + enums.ProtectionStorageMask.HARDWARE.value ) ) self.assertEqual( '{' '"object_type": ObjectType.SYMMETRIC_KEY, ' - '"template_attribute": Struct()' + '"template_attribute": Struct(), ' + '"protection_storage_masks": 3' '}', str(payload) ) @@ -660,6 +705,10 @@ class TestCreateRequestPayload(testtools.TestCase): ) ) ] + ), + protection_storage_masks=( + enums.ProtectionStorageMask.SOFTWARE.value | + enums.ProtectionStorageMask.HARDWARE.value ) ) b = payloads.CreateRequestPayload( @@ -698,6 +747,10 @@ class TestCreateRequestPayload(testtools.TestCase): ) ) ] + ), + protection_storage_masks=( + enums.ProtectionStorageMask.SOFTWARE.value | + enums.ProtectionStorageMask.HARDWARE.value ) ) @@ -759,6 +812,27 @@ class TestCreateRequestPayload(testtools.TestCase): self.assertFalse(a == b) self.assertFalse(b == a) + def test_equal_on_not_equal_protection_storage_masks(self): + """ + Test that the equality operator returns False when comparing two Create + request payloads with different protection storage masks. + """ + a = payloads.CreateRequestPayload( + protection_storage_masks=( + enums.ProtectionStorageMask.SOFTWARE.value | + enums.ProtectionStorageMask.HARDWARE.value + ) + ) + b = payloads.CreateRequestPayload( + protection_storage_masks=( + enums.ProtectionStorageMask.ON_SYSTEM.value | + enums.ProtectionStorageMask.OFF_SYSTEM.value + ) + ) + + self.assertFalse(a == b) + self.assertFalse(b == a) + def test_equal_on_type_mismatch(self): """ Test that the equality operator returns False when comparing two Create @@ -817,6 +891,10 @@ class TestCreateRequestPayload(testtools.TestCase): ) ) ] + ), + protection_storage_masks=( + enums.ProtectionStorageMask.SOFTWARE.value | + enums.ProtectionStorageMask.HARDWARE.value ) ) b = payloads.CreateRequestPayload( @@ -855,6 +933,10 @@ class TestCreateRequestPayload(testtools.TestCase): ) ) ] + ), + protection_storage_masks=( + enums.ProtectionStorageMask.SOFTWARE.value | + enums.ProtectionStorageMask.HARDWARE.value ) ) @@ -916,6 +998,27 @@ class TestCreateRequestPayload(testtools.TestCase): self.assertTrue(a != b) self.assertTrue(b != a) + def test_not_equal_on_not_equal_protection_storage_masks(self): + """ + Test that the inequality operator returns True when comparing two + Create request payloads with different protection storage masks. + """ + a = payloads.CreateRequestPayload( + protection_storage_masks=( + enums.ProtectionStorageMask.SOFTWARE.value | + enums.ProtectionStorageMask.HARDWARE.value + ) + ) + b = payloads.CreateRequestPayload( + protection_storage_masks=( + enums.ProtectionStorageMask.ON_SYSTEM.value | + enums.ProtectionStorageMask.OFF_SYSTEM.value + ) + ) + + self.assertTrue(a != b) + self.assertTrue(b != a) + def test_not_equal_on_type_mismatch(self): """ Test that the inequality operator returns True when comparing two