From dd4a078cc15fbe5ff0336eb9188a3af62e338b96 Mon Sep 17 00:00:00 2001 From: Peter Hamilton Date: Thu, 22 Feb 2018 13:35:11 -0500 Subject: [PATCH] Update the ProtocolVersion implementation This change updates the implementation of the ProtocolVersion struct, bringing it inline with the current struct style. All uses of the struct have been updated to reflect these changes, as have the struct unit tests. --- kmip/core/messages/contents.py | 233 ++++++----- kmip/core/messages/messages.py | 3 - kmip/demos/units/discover_versions.py | 3 +- kmip/services/kmip_client.py | 2 +- kmip/services/server/engine.py | 10 +- kmip/services/server/policy.py | 84 ++-- kmip/services/server/session.py | 2 +- .../contents/test_protocol_version.py | 369 +++++++++++------- .../payloads/test_discover_versions.py | 12 +- .../tests/unit/core/messages/test_messages.py | 202 +++++----- .../tests/unit/services/server/test_engine.py | 44 +-- .../tests/unit/services/server/test_policy.py | 14 +- .../unit/services/server/test_session.py | 2 +- kmip/tests/unit/services/test_kmip_client.py | 4 +- 14 files changed, 557 insertions(+), 427 deletions(-) diff --git a/kmip/core/messages/contents.py b/kmip/core/messages/contents.py index 9027f7a..8d4ebb4 100644 --- a/kmip/core/messages/contents.py +++ b/kmip/core/messages/contents.py @@ -13,10 +13,13 @@ # License for the specific language governing permissions and limitations # under the License. +import six + from kmip.core import enums from kmip.core import objects from kmip.core import utils +from kmip.core import primitives from kmip.core.primitives import Struct from kmip.core.primitives import Integer from kmip.core.primitives import Enumeration @@ -26,105 +29,170 @@ from kmip.core.primitives import ByteString from kmip.core.primitives import DateTime -# 6.1 -class ProtocolVersion(Struct): +class ProtocolVersion(primitives.Struct): + """ + A struct representing a ProtocolVersion number. - class ProtocolVersionMajor(Integer): - def __init__(self, value=None): - super(ProtocolVersion.ProtocolVersionMajor, self).\ - __init__(value, enums.Tags.PROTOCOL_VERSION_MAJOR) + Attributes: + major: The major protocol version number. + minor: The minor protocol version number. + """ - class ProtocolVersionMinor(Integer): - def __init__(self, value=None): - super(ProtocolVersion.ProtocolVersionMinor, self).\ - __init__(value, enums.Tags.PROTOCOL_VERSION_MINOR) + def __init__(self, major=None, minor=None): + """ + Construct a ProtocolVersion struct. - def __init__(self, - protocol_version_major=None, - protocol_version_minor=None): + Args: + major (int): The major protocol version number. Optional, defaults + to None. + minor (int): The minor protocol version number. Optional, defaults + to None. + """ super(ProtocolVersion, self).__init__(enums.Tags.PROTOCOL_VERSION) - if protocol_version_major is None: - self.protocol_version_major = \ - ProtocolVersion.ProtocolVersionMajor() + self._major = None + self._minor = None + + self.major = major + self.minor = minor + + @property + def major(self): + if self._major: + return self._major.value else: - self.protocol_version_major = protocol_version_major + return None - if protocol_version_minor is None: - self.protocol_version_minor = \ - ProtocolVersion.ProtocolVersionMinor() + @major.setter + def major(self, value): + if value is None: + self._major = None + elif isinstance(value, six.integer_types): + self._major = primitives.Integer( + value=value, + tag=enums.Tags.PROTOCOL_VERSION_MAJOR + ) else: - self.protocol_version_minor = protocol_version_minor + raise TypeError( + "Major protocol version number must be an integer." + ) - self.validate() + @property + def minor(self): + if self._minor: + return self._minor.value + else: + return None - def read(self, istream): - super(ProtocolVersion, self).read(istream) - tstream = utils.BytearrayStream(istream.read(self.length)) + @minor.setter + def minor(self, value): + if value is None: + self._minor = None + elif isinstance(value, six.integer_types): + self._minor = primitives.Integer( + value=value, + tag=enums.Tags.PROTOCOL_VERSION_MINOR + ) + else: + raise TypeError( + "Minor protocol version number must be an integer." + ) - # Read the major and minor portions of the version number - self.protocol_version_major.read(tstream) - self.protocol_version_minor.read(tstream) + def read(self, input_stream): + """ + Read the data encoding the ProtocolVersion struct and decode it into + its constituent parts. - self.is_oversized(tstream) + Args: + input_stream (stream): A data stream containing encoded object + data, supporting a read method; usually a BytearrayStream + object. - def write(self, ostream): - tstream = utils.BytearrayStream() + Raises: + ValueError: Raised if either the major or minor protocol versions + are missing from the encoding. + """ + super(ProtocolVersion, self).read(input_stream) + local_stream = utils.BytearrayStream(input_stream.read(self.length)) - # Write the major and minor portions of the protocol version - self.protocol_version_major.write(tstream) - self.protocol_version_minor.write(tstream) + if self.is_tag_next(enums.Tags.PROTOCOL_VERSION_MAJOR, local_stream): + self._major = primitives.Integer( + tag=enums.Tags.PROTOCOL_VERSION_MAJOR + ) + self._major.read(local_stream) + else: + raise ValueError( + "Invalid encoding missing the major protocol version number." + ) - # Write the length and value of the protocol version - self.length = tstream.length() - super(ProtocolVersion, self).write(ostream) - ostream.write(tstream.buffer) + if self.is_tag_next(enums.Tags.PROTOCOL_VERSION_MINOR, local_stream): + self._minor = primitives.Integer( + tag=enums.Tags.PROTOCOL_VERSION_MINOR + ) + self._minor.read(local_stream) + else: + raise ValueError( + "Invalid encoding missing the minor protocol version number." + ) - def validate(self): - self.__validate() + self.is_oversized(local_stream) - def __validate(self): - if not isinstance(self.protocol_version_major, - ProtocolVersion.ProtocolVersionMajor): - msg = "invalid protocol version major" - msg += "; expected {0}, received {1}".format( - ProtocolVersion.ProtocolVersionMajor, - self.protocol_version_major) - raise TypeError(msg) + def write(self, output_stream): + """ + Write the data encoding the ProtocolVersion struct to a stream. - if not isinstance(self.protocol_version_minor, - ProtocolVersion.ProtocolVersionMinor): - msg = "invalid protocol version minor" - msg += "; expected {0}, received {1}".format( - ProtocolVersion.ProtocolVersionMinor, - self.protocol_version_minor) - raise TypeError(msg) + Args: + output_stream (stream): A data stream in which to encode object + data, supporting a write method; usually a BytearrayStream + object. + + Raises: + ValueError: Raised if the data attribute is not defined. + """ + local_stream = utils.BytearrayStream() + + if self._major: + self._major.write(local_stream) + else: + raise ValueError( + "Invalid struct missing the major protocol version number." + ) + + if self._minor: + self._minor.write(local_stream) + else: + raise ValueError( + "Invalid struct missing the minor protocol version number." + ) + + self.length = local_stream.length() + super(ProtocolVersion, self).write(output_stream) + output_stream.write(local_stream.buffer) def __eq__(self, other): if isinstance(other, ProtocolVersion): - if ((self.protocol_version_major == - other.protocol_version_major) and - (self.protocol_version_minor == - other.protocol_version_minor)): - return True - else: + if self.major != other.major: return False + elif self.minor != other.minor: + return False + else: + return True else: return NotImplemented def __ne__(self, other): if isinstance(other, ProtocolVersion): - return not self.__eq__(other) + return not (self == other) else: return NotImplemented def __lt__(self, other): if isinstance(other, ProtocolVersion): - if self.protocol_version_major < other.protocol_version_major: + if self.major < other.major: return True - elif self.protocol_version_major > other.protocol_version_major: + elif self.major > other.major: return False - elif self.protocol_version_minor < other.protocol_version_minor: + elif self.minor < other.minor: return True else: return False @@ -133,24 +201,16 @@ class ProtocolVersion(Struct): def __gt__(self, other): if isinstance(other, ProtocolVersion): - if self.protocol_version_major > other.protocol_version_major: - return True - elif self.protocol_version_major < other.protocol_version_major: + if (self == other) or (self < other): return False - elif self.protocol_version_minor > other.protocol_version_minor: - return True else: - return False + return True else: return NotImplemented def __le__(self, other): if isinstance(other, ProtocolVersion): - if self.protocol_version_major < other.protocol_version_major: - return True - elif self.protocol_version_major > other.protocol_version_major: - return False - elif self.protocol_version_minor <= other.protocol_version_minor: + if (self == other) or (self < other): return True else: return False @@ -159,11 +219,7 @@ class ProtocolVersion(Struct): def __ge__(self, other): if isinstance(other, ProtocolVersion): - if self.protocol_version_major > other.protocol_version_major: - return True - elif self.protocol_version_major < other.protocol_version_major: - return False - elif self.protocol_version_minor >= other.protocol_version_minor: + if (self == other) or (self > other): return True else: return False @@ -171,15 +227,14 @@ class ProtocolVersion(Struct): return NotImplemented def __repr__(self): - major = self.protocol_version_major.value - minor = self.protocol_version_minor.value - return "{0}.{1}".format(major, minor) + args = ", ".join([ + "major={}".format(self.major), + "minor={}".format(self.minor) + ]) + return "ProtocolVersion({})".format(args) - @classmethod - def create(cls, major, minor): - major = cls.ProtocolVersionMajor(major) - minor = cls.ProtocolVersionMinor(minor) - return ProtocolVersion(major, minor) + def __str__(self): + return "{}.{}".format(self.major, self.minor) # 6.2 diff --git a/kmip/core/messages/messages.py b/kmip/core/messages/messages.py index c63b837..11b2595 100644 --- a/kmip/core/messages/messages.py +++ b/kmip/core/messages/messages.py @@ -157,9 +157,6 @@ class ResponseHeader(Struct): ostream.write(tstream.buffer) def validate(self): - if self.protocol_version is not None: - # TODO (peter-hamilton) conduct type check - self.protocol_version.validate() if self.time_stamp is not None: # TODO (peter-hamilton) conduct type check self.time_stamp.validate() diff --git a/kmip/demos/units/discover_versions.py b/kmip/demos/units/discover_versions.py index 6155347..ad9a662 100644 --- a/kmip/demos/units/discover_versions.py +++ b/kmip/demos/units/discover_versions.py @@ -42,8 +42,7 @@ if __name__ == '__main__': if opts.protocol_versions is not None: for version in re.split(',| ', opts.protocol_versions): mm = re.split('\.', version) - protocol_versions.append(ProtocolVersion.create(int(mm[0]), - int(mm[1]))) + protocol_versions.append(ProtocolVersion(int(mm[0]), int(mm[1]))) # Build the client and connect to the server client = KMIPProxy(config=config) diff --git a/kmip/services/kmip_client.py b/kmip/services/kmip_client.py index e1889c5..d145403 100644 --- a/kmip/services/kmip_client.py +++ b/kmip/services/kmip_client.py @@ -1341,7 +1341,7 @@ class KMIPProxy(KMIP): return credential def _build_request_message(self, credential, batch_items): - protocol_version = ProtocolVersion.create(1, 2) + protocol_version = ProtocolVersion(1, 2) if credential is None: credential = self._build_credential() diff --git a/kmip/services/server/engine.py b/kmip/services/server/engine.py index 28f6c83..3d6cdf1 100644 --- a/kmip/services/server/engine.py +++ b/kmip/services/server/engine.py @@ -105,9 +105,9 @@ class KmipEngine(object): self._id_placeholder = None self._protocol_versions = [ - contents.ProtocolVersion.create(1, 2), - contents.ProtocolVersion.create(1, 1), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 2), + contents.ProtocolVersion(1, 1), + contents.ProtocolVersion(1, 0) ] self._protocol_version = self._protocol_versions[0] @@ -2001,11 +2001,11 @@ class KmipEngine(object): contents.Operation(enums.Operation.QUERY) ]) - if self._protocol_version >= contents.ProtocolVersion.create(1, 1): + if self._protocol_version >= contents.ProtocolVersion(1, 1): operations.extend([ contents.Operation(enums.Operation.DISCOVER_VERSIONS) ]) - if self._protocol_version >= contents.ProtocolVersion.create(1, 2): + if self._protocol_version >= contents.ProtocolVersion(1, 2): operations.extend([ contents.Operation(enums.Operation.ENCRYPT), contents.Operation(enums.Operation.DECRYPT), diff --git a/kmip/services/server/policy.py b/kmip/services/server/policy.py index 760081c..a266648 100644 --- a/kmip/services/server/policy.py +++ b/kmip/services/server/policy.py @@ -157,7 +157,7 @@ class AttributePolicy(object): enums.ObjectType.SECRET_DATA, enums.ObjectType.OPAQUE_DATA ), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ), 'Name': AttributeRuleSet( False, @@ -181,7 +181,7 @@ class AttributePolicy(object): enums.ObjectType.SECRET_DATA, enums.ObjectType.OPAQUE_DATA ), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ), 'Object Type': AttributeRuleSet( True, @@ -210,7 +210,7 @@ class AttributePolicy(object): enums.ObjectType.SECRET_DATA, enums.ObjectType.OPAQUE_DATA ), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ), 'Cryptographic Algorithm': AttributeRuleSet( True, @@ -237,7 +237,7 @@ class AttributePolicy(object): enums.ObjectType.SPLIT_KEY, enums.ObjectType.TEMPLATE ), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ), 'Cryptographic Length': AttributeRuleSet( True, @@ -264,7 +264,7 @@ class AttributePolicy(object): enums.ObjectType.SPLIT_KEY, enums.ObjectType.TEMPLATE ), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ), 'Cryptographic Parameters': AttributeRuleSet( False, @@ -286,7 +286,7 @@ class AttributePolicy(object): enums.ObjectType.SPLIT_KEY, enums.ObjectType.TEMPLATE ), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ), 'Cryptographic Domain Parameters': AttributeRuleSet( False, @@ -304,7 +304,7 @@ class AttributePolicy(object): enums.ObjectType.PRIVATE_KEY, enums.ObjectType.TEMPLATE ), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ), 'Certificate Type': AttributeRuleSet( True, @@ -321,7 +321,7 @@ class AttributePolicy(object): ( enums.ObjectType.CERTIFICATE, ), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ), 'Certificate Length': AttributeRuleSet( True, @@ -338,7 +338,7 @@ class AttributePolicy(object): ( enums.ObjectType.CERTIFICATE, ), - contents.ProtocolVersion.create(1, 1) + contents.ProtocolVersion(1, 1) ), 'X.509 Certificate Identifier': AttributeRuleSet( True, @@ -356,7 +356,7 @@ class AttributePolicy(object): # TODO (peterhamilton) Enforce only on X.509 certificates enums.ObjectType.CERTIFICATE, ), - contents.ProtocolVersion.create(1, 1) + contents.ProtocolVersion(1, 1) ), 'X.509 Certificate Subject': AttributeRuleSet( True, @@ -374,7 +374,7 @@ class AttributePolicy(object): # TODO (peterhamilton) Enforce only on X.509 certificates enums.ObjectType.CERTIFICATE, ), - contents.ProtocolVersion.create(1, 1) + contents.ProtocolVersion(1, 1) ), 'X.509 Certificate Issuer': AttributeRuleSet( True, @@ -392,7 +392,7 @@ class AttributePolicy(object): # TODO (peterhamilton) Enforce only on X.509 certificates enums.ObjectType.CERTIFICATE, ), - contents.ProtocolVersion.create(1, 1) + contents.ProtocolVersion(1, 1) ), 'Certificate Identifier': AttributeRuleSet( True, @@ -409,8 +409,8 @@ class AttributePolicy(object): ( enums.ObjectType.CERTIFICATE, ), - contents.ProtocolVersion.create(1, 0), - contents.ProtocolVersion.create(1, 1) + contents.ProtocolVersion(1, 0), + contents.ProtocolVersion(1, 1) ), 'Certificate Subject': AttributeRuleSet( True, @@ -427,8 +427,8 @@ class AttributePolicy(object): ( enums.ObjectType.CERTIFICATE, ), - contents.ProtocolVersion.create(1, 0), - contents.ProtocolVersion.create(1, 1) + contents.ProtocolVersion(1, 0), + contents.ProtocolVersion(1, 1) ), 'Certificate Issuer': AttributeRuleSet( True, @@ -445,8 +445,8 @@ class AttributePolicy(object): ( enums.ObjectType.CERTIFICATE, ), - contents.ProtocolVersion.create(1, 0), - contents.ProtocolVersion.create(1, 1) + contents.ProtocolVersion(1, 0), + contents.ProtocolVersion(1, 1) ), 'Digital Signature Algorithm': AttributeRuleSet( True, @@ -464,7 +464,7 @@ class AttributePolicy(object): ( enums.ObjectType.CERTIFICATE, ), - contents.ProtocolVersion.create(1, 1) + contents.ProtocolVersion(1, 1) ), 'Digest': AttributeRuleSet( True, # If the server has access to the data @@ -492,7 +492,7 @@ class AttributePolicy(object): enums.ObjectType.SECRET_DATA, enums.ObjectType.OPAQUE_DATA ), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ), 'Operation Policy Name': AttributeRuleSet( False, @@ -521,7 +521,7 @@ class AttributePolicy(object): enums.ObjectType.SECRET_DATA, enums.ObjectType.OPAQUE_DATA ), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ), 'Cryptographic Usage Mask': AttributeRuleSet( True, @@ -549,7 +549,7 @@ class AttributePolicy(object): enums.ObjectType.TEMPLATE, enums.ObjectType.SECRET_DATA ), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ), 'Lease Time': AttributeRuleSet( False, @@ -576,7 +576,7 @@ class AttributePolicy(object): enums.ObjectType.SPLIT_KEY, enums.ObjectType.SECRET_DATA ), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ), 'Usage Limits': AttributeRuleSet( False, @@ -601,7 +601,7 @@ class AttributePolicy(object): enums.ObjectType.SPLIT_KEY, enums.ObjectType.TEMPLATE ), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ), 'State': AttributeRuleSet( True, @@ -631,7 +631,7 @@ class AttributePolicy(object): enums.ObjectType.SPLIT_KEY, enums.ObjectType.SECRET_DATA ), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ), 'Initial Date': AttributeRuleSet( True, @@ -660,7 +660,7 @@ class AttributePolicy(object): enums.ObjectType.SECRET_DATA, enums.ObjectType.OPAQUE_DATA ), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ), 'Activation Date': AttributeRuleSet( False, @@ -689,7 +689,7 @@ class AttributePolicy(object): enums.ObjectType.TEMPLATE, enums.ObjectType.SECRET_DATA ), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ), 'Process Start Date': AttributeRuleSet( False, @@ -710,7 +710,7 @@ class AttributePolicy(object): enums.ObjectType.SPLIT_KEY, enums.ObjectType.TEMPLATE ), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ), 'Protect Stop Date': AttributeRuleSet( False, @@ -731,7 +731,7 @@ class AttributePolicy(object): enums.ObjectType.SPLIT_KEY, enums.ObjectType.TEMPLATE ), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ), 'Deactivation Date': AttributeRuleSet( False, @@ -760,7 +760,7 @@ class AttributePolicy(object): enums.ObjectType.TEMPLATE, enums.ObjectType.SECRET_DATA ), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ), 'Destroy Date': AttributeRuleSet( False, @@ -781,7 +781,7 @@ class AttributePolicy(object): enums.ObjectType.SECRET_DATA, enums.ObjectType.OPAQUE_DATA ), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ), 'Compromise Occurrence Date': AttributeRuleSet( False, @@ -802,7 +802,7 @@ class AttributePolicy(object): enums.ObjectType.SECRET_DATA, enums.ObjectType.OPAQUE_DATA ), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ), 'Compromise Date': AttributeRuleSet( False, @@ -823,7 +823,7 @@ class AttributePolicy(object): enums.ObjectType.SECRET_DATA, enums.ObjectType.OPAQUE_DATA ), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ), 'Revocation Reason': AttributeRuleSet( False, @@ -844,7 +844,7 @@ class AttributePolicy(object): enums.ObjectType.SECRET_DATA, enums.ObjectType.OPAQUE_DATA ), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ), 'Archive Date': AttributeRuleSet( False, @@ -866,7 +866,7 @@ class AttributePolicy(object): enums.ObjectType.SECRET_DATA, enums.ObjectType.OPAQUE_DATA ), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ), 'Object Group': AttributeRuleSet( False, @@ -895,7 +895,7 @@ class AttributePolicy(object): enums.ObjectType.SECRET_DATA, enums.ObjectType.OPAQUE_DATA ), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ), 'Fresh': AttributeRuleSet( False, @@ -924,7 +924,7 @@ class AttributePolicy(object): enums.ObjectType.SECRET_DATA, enums.ObjectType.OPAQUE_DATA ), - contents.ProtocolVersion.create(1, 1) + contents.ProtocolVersion(1, 1) ), 'Link': AttributeRuleSet( False, @@ -951,7 +951,7 @@ class AttributePolicy(object): enums.ObjectType.SECRET_DATA, enums.ObjectType.OPAQUE_DATA ), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ), 'Application Specific Information': AttributeRuleSet( False, @@ -975,7 +975,7 @@ class AttributePolicy(object): enums.ObjectType.SECRET_DATA, enums.ObjectType.OPAQUE_DATA ), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ), 'Contact Information': AttributeRuleSet( False, @@ -1004,7 +1004,7 @@ class AttributePolicy(object): enums.ObjectType.SECRET_DATA, enums.ObjectType.OPAQUE_DATA ), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ), 'Last Change Date': AttributeRuleSet( True, @@ -1042,7 +1042,7 @@ class AttributePolicy(object): enums.ObjectType.SECRET_DATA, enums.ObjectType.OPAQUE_DATA ), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ), 'Custom Attribute': AttributeRuleSet( False, @@ -1074,7 +1074,7 @@ class AttributePolicy(object): enums.ObjectType.SECRET_DATA, enums.ObjectType.OPAQUE_DATA ), - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ), } diff --git a/kmip/services/server/session.py b/kmip/services/server/session.py index 03291f3..cde057d 100644 --- a/kmip/services/server/session.py +++ b/kmip/services/server/session.py @@ -176,7 +176,7 @@ class KmipSession(threading.Thread): self._logger.warning("Failure parsing request message.") self._logger.exception(e) response = self._engine.build_error_response( - contents.ProtocolVersion.create(1, 0), + contents.ProtocolVersion(1, 0), enums.ResultReason.INVALID_MESSAGE, "Error parsing request message. See server logs for more " "information." diff --git a/kmip/tests/unit/core/messages/contents/test_protocol_version.py b/kmip/tests/unit/core/messages/contents/test_protocol_version.py index e84901c..448225d 100644 --- a/kmip/tests/unit/core/messages/contents/test_protocol_version.py +++ b/kmip/tests/unit/core/messages/contents/test_protocol_version.py @@ -13,171 +13,243 @@ # License for the specific language governing permissions and limitations # under the License. -from testtools import TestCase +import testtools -from kmip.core.messages.contents import ProtocolVersion -from kmip.core.utils import BytearrayStream +from kmip.core.messages import contents +from kmip.core import utils -class TestProtocolVersion(TestCase): +class TestProtocolVersion(testtools.TestCase): def setUp(self): super(TestProtocolVersion, self).setUp() - self.major_default = ProtocolVersion.ProtocolVersionMajor() - self.minor_default = ProtocolVersion.ProtocolVersionMinor() - self.major = ProtocolVersion.ProtocolVersionMajor(1) - self.minor = ProtocolVersion.ProtocolVersionMinor(1) + # Encoding obtained from the KMIP 1.1 testing document, Section 3.1.1. + # + # This encoding matches the following set of values: + # ProtocolVersion + # ProtocolVersionMajor - 1 + # ProtocolVersionMinor - 1 - self.encoding_default = BytearrayStream(( - b'\x42\x00\x69\x01\x00\x00\x00\x20\x42\x00\x6A\x02\x00\x00\x00\x04' - b'\x00\x00\x00\x00\x00\x00\x00\x00\x42\x00\x6B\x02\x00\x00\x00\x04' - b'\x00\x00\x00\x00\x00\x00\x00\x00')) - self.encoding = BytearrayStream(( - b'\x42\x00\x69\x01\x00\x00\x00\x20\x42\x00\x6A\x02\x00\x00\x00\x04' - b'\x00\x00\x00\x01\x00\x00\x00\x00\x42\x00\x6B\x02\x00\x00\x00\x04' - b'\x00\x00\x00\x01\x00\x00\x00\x00')) + self.full_encoding = utils.BytearrayStream( + b'\x42\x00\x69\x01\x00\x00\x00\x20' + b'\x42\x00\x6A\x02\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x00' + b'\x42\x00\x6B\x02\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x00' + ) + + self.encoding_no_major_number = utils.BytearrayStream( + b'\x42\x00\x69\x01\x00\x00\x00\x10' + b'\x42\x00\x6B\x02\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x00' + ) + + self.encoding_no_minor_number = utils.BytearrayStream( + b'\x42\x00\x69\x01\x00\x00\x00\x10' + b'\x42\x00\x6A\x02\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x00' + ) def tearDown(self): super(TestProtocolVersion, self).tearDown() - def _test_init(self, protocol_version_major, protocol_version_minor): - protocol_version = ProtocolVersion( - protocol_version_major, protocol_version_minor) + def test_init(self): + """ + Test that a ProtocolVersion struct can be constructed with no + arguments. + """ + struct = contents.ProtocolVersion() - if protocol_version_major is None: - self.assertEqual(ProtocolVersion.ProtocolVersionMajor(), - protocol_version.protocol_version_major) - else: - self.assertEqual(protocol_version_major, - protocol_version.protocol_version_major) - - if protocol_version_minor is None: - self.assertEqual(ProtocolVersion.ProtocolVersionMinor(), - protocol_version.protocol_version_minor) - else: - self.assertEqual(protocol_version_minor, - protocol_version.protocol_version_minor) - - def test_init_with_none(self): - self._test_init(None, None) + self.assertEqual(None, struct.major) + self.assertEqual(None, struct.minor) def test_init_with_args(self): - major = ProtocolVersion.ProtocolVersionMajor(1) - minor = ProtocolVersion.ProtocolVersionMinor(0) + """ + Test that a ProtocolVersion struct can be constructed with valid + values. + """ + struct = contents.ProtocolVersion(1, 1) - self._test_init(major, minor) - - def test_validate_on_invalid_protocol_version_major(self): - major = "invalid" - minor = ProtocolVersion.ProtocolVersionMinor(0) - args = [major, minor] + self.assertEqual(1, struct.major) + self.assertEqual(1, struct.minor) + def test_invalid_protocol_version_major(self): + """ + Test that a TypeError is raised when an invalid value is used to set + the major protocol version number of a ProtocolVersion struct. + """ + struct = contents.ProtocolVersion() + args = (struct, 'major', 'invalid') self.assertRaisesRegexp( - TypeError, "invalid protocol version major", self._test_init, - *args) - - def test_validate_on_invalid_protocol_version_minor(self): - major = ProtocolVersion.ProtocolVersionMajor(1) - minor = "invalid" - args = [major, minor] + TypeError, + "Major protocol version number must be an integer.", + setattr, + *args + ) + def test_invalid_protocol_version_minor(self): + """ + Test that a TypeError is raised when an invalid value is used to set + the minor protocol version number of a ProtocolVersion struct. + """ + struct = contents.ProtocolVersion() + args = (struct, 'minor', 'invalid') self.assertRaisesRegexp( - TypeError, "invalid protocol version minor", self._test_init, - *args) + TypeError, + "Minor protocol version number must be an integer.", + setattr, + *args + ) - def _test_read(self, stream, major, minor): - protocol_version = ProtocolVersion() - protocol_version.read(stream) + def test_read(self): + """ + Test that a ProtocolVersion struct can be read from a data stream. + """ + struct = contents.ProtocolVersion() - msg = "protocol version major decoding mismatch" - msg += "; expected {0}, received {1}".format( - major, protocol_version.protocol_version_major) - self.assertEqual(major, protocol_version.protocol_version_major, msg) + self.assertEqual(None, struct.major) + self.assertEqual(None, struct.minor) - msg = "protocol version minor decoding mismatch" - msg += "; expected {0}, received {1}".format( - minor, protocol_version.protocol_version_minor) - self.assertEqual(minor, protocol_version.protocol_version_minor, msg) + struct.read(self.full_encoding) - def test_read_with_none(self): - self._test_read(self.encoding_default, self.major_default, - self.minor_default) + self.assertEqual(1, struct.major) + self.assertEqual(1, struct.minor) - def test_read_with_args(self): - self._test_read(self.encoding, self.major, self.minor) + def test_read_missing_major_number(self): + """ + Test that a ValueError gets raised when a required ProtocolVersion + struct attribute is missing from the struct encoding. + """ + struct = contents.ProtocolVersion() + args = (self.encoding_no_major_number, ) + self.assertRaisesRegexp( + ValueError, + "Invalid encoding missing the major protocol version number.", + struct.read, + *args + ) - def _test_write(self, stream_expected, major, minor): - stream_observed = BytearrayStream() - protocol_version = ProtocolVersion(major, minor) - protocol_version.write(stream_observed) + def test_read_missing_minor_number(self): + """ + Test that a ValueError gets raised when a required ProtocolVersion + struct attribute is missing from the struct encoding. + """ + struct = contents.ProtocolVersion() + args = (self.encoding_no_minor_number, ) + self.assertRaisesRegexp( + ValueError, + "Invalid encoding missing the minor protocol version number.", + struct.read, + *args + ) - length_expected = len(stream_expected) - length_observed = len(stream_observed) + def test_write(self): + """ + Test that a ProtocolVersion struct can be written to a data stream. + """ + struct = contents.ProtocolVersion(1, 1) + stream = utils.BytearrayStream() + struct.write(stream) - msg = "encoding lengths not equal" - msg += "; expected {0}, received {1}".format( - length_expected, length_observed) - self.assertEqual(length_expected, length_observed, msg) + self.assertEqual(len(self.full_encoding), len(stream)) + self.assertEqual(str(self.full_encoding), str(stream)) - msg = "encoding mismatch" - msg += ";\nexpected:\n{0}\nreceived:\n{1}".format( - stream_expected, stream_observed) + def test_write_missing_major_number(self): + """ + Test that a ValueError gets raised when a required ProtocolVersion + struct attribute is missing when encoding the struct. + """ + struct = contents.ProtocolVersion(None, 1) + stream = utils.BytearrayStream() + args = (stream, ) + self.assertRaisesRegexp( + ValueError, + "Invalid struct missing the major protocol version number.", + struct.write, + *args + ) - self.assertEqual(stream_expected, stream_observed, msg) - - def test_write_with_none(self): - self._test_write(self.encoding_default, self.major_default, - self.minor_default) - - def test_write_with_args(self): - self._test_write(self.encoding, self.major, self.minor) + def test_write_missing_minor_number(self): + """ + Test that a ValueError gets raised when a required ProtocolVersion + struct attribute is missing when encoding the struct. + """ + struct = contents.ProtocolVersion(1, None) + stream = utils.BytearrayStream() + args = (stream, ) + self.assertRaisesRegexp( + ValueError, + "Invalid struct missing the minor protocol version number.", + struct.write, + *args + ) def test_equal_on_equal(self): - a = ProtocolVersion.create(1, 0) - b = ProtocolVersion.create(1, 0) + """ + Test that the equality operator returns True when comparing two + ProtocolVersion structs with the same data. + """ + a = contents.ProtocolVersion(1, 0) + b = contents.ProtocolVersion(1, 0) self.assertTrue(a == b) def test_equal_on_not_equal(self): - a = ProtocolVersion.create(1, 0) - b = ProtocolVersion.create(0, 1) + """ + Test that the equality operator returns False when comparing two + ProtocolVersion structs with different data. + """ + a = contents.ProtocolVersion(1, 0) + b = contents.ProtocolVersion(0, 1) self.assertFalse(a == b) def test_equal_on_type_mismatch(self): - a = ProtocolVersion.create(1, 0) + """ + Test that the equality operator returns False when comparing two + ProtocolVersion structs with different types. + """ + a = contents.ProtocolVersion(1, 0) b = "invalid" self.assertFalse(a == b) def test_not_equal_on_equal(self): - a = ProtocolVersion.create(1, 0) - b = ProtocolVersion.create(1, 0) + """ + Test that the inequality operator returns False when comparing two + ProtocolVersion structs with the same data. + """ + a = contents.ProtocolVersion(1, 0) + b = contents.ProtocolVersion(1, 0) self.assertFalse(a != b) def test_not_equal_on_not_equal(self): - a = ProtocolVersion.create(1, 0) - b = ProtocolVersion.create(0, 1) + """ + Test that the inequality operator returns True when comparing two + ProtocolVersion structs with different data. + """ + a = contents.ProtocolVersion(1, 0) + b = contents.ProtocolVersion(0, 1) self.assertTrue(a != b) def test_not_equal_on_type_mismatch(self): - a = ProtocolVersion.create(1, 0) + """ + Test that the inequality operator returns True when comparing two + ProtocolVersion structs with different types. + """ + a = contents.ProtocolVersion(1, 0) b = "invalid" self.assertTrue(a != b) def test_less_than(self): """ - Test that the less than operator returns True/False when comparing - two different ProtocolVersions. + Test that the less than operator correctly returns True/False when + comparing two different ProtocolVersions. """ - a = ProtocolVersion.create(1, 0) - b = ProtocolVersion.create(1, 1) - c = ProtocolVersion.create(2, 0) - d = ProtocolVersion.create(0, 2) + a = contents.ProtocolVersion(1, 0) + b = contents.ProtocolVersion(1, 1) + c = contents.ProtocolVersion(2, 0) + d = contents.ProtocolVersion(0, 2) self.assertTrue(a < b) self.assertFalse(b < a) @@ -187,15 +259,19 @@ class TestProtocolVersion(TestCase): self.assertFalse(c < d) self.assertTrue(d < c) + # A direct call to __lt__ is required here due to differences in how + # Python 2 and Python 3 treat comparison operators. + self.assertEqual(NotImplemented, a.__lt__('invalid')) + def test_greater_than(self): """ - Test that the greater than operator returns True/False when + Test that the greater than operator correctly returns True/False when comparing two different ProtocolVersions. """ - a = ProtocolVersion.create(1, 0) - b = ProtocolVersion.create(1, 1) - c = ProtocolVersion.create(2, 0) - d = ProtocolVersion.create(0, 2) + a = contents.ProtocolVersion(1, 0) + b = contents.ProtocolVersion(1, 1) + c = contents.ProtocolVersion(2, 0) + d = contents.ProtocolVersion(0, 2) self.assertFalse(a > b) self.assertTrue(b > a) @@ -205,15 +281,19 @@ class TestProtocolVersion(TestCase): self.assertTrue(c > d) self.assertFalse(d > c) + # A direct call to __gt__ is required here due to differences in how + # Python 2 and Python 3 treat comparison operators. + self.assertEqual(NotImplemented, a.__gt__('invalid')) + def test_less_than_or_equal(self): """ - Test that the less than or equal operator returns True/False when - comparing two different ProtocolVersions. + Test that the less than or equal operator correctly returns True/False + when comparing two different ProtocolVersions. """ - a = ProtocolVersion.create(1, 0) - b = ProtocolVersion.create(1, 1) - c = ProtocolVersion.create(2, 0) - d = ProtocolVersion.create(0, 2) + a = contents.ProtocolVersion(1, 0) + b = contents.ProtocolVersion(1, 1) + c = contents.ProtocolVersion(2, 0) + d = contents.ProtocolVersion(0, 2) self.assertTrue(a <= b) self.assertFalse(b <= a) @@ -223,15 +303,19 @@ class TestProtocolVersion(TestCase): self.assertFalse(c <= d) self.assertTrue(d <= c) + # A direct call to __le__ is required here due to differences in how + # Python 2 and Python 3 treat comparison operators. + self.assertEqual(NotImplemented, a.__le__('invalid')) + def test_greater_than_or_equal(self): """ - Test that the greater than or equal operator returns True/False when - comparing two different ProtocolVersions. + Test that the greater than or equal operator correctly returns + True/False when comparing two different ProtocolVersions. """ - a = ProtocolVersion.create(1, 0) - b = ProtocolVersion.create(1, 1) - c = ProtocolVersion.create(2, 0) - d = ProtocolVersion.create(0, 2) + a = contents.ProtocolVersion(1, 0) + b = contents.ProtocolVersion(1, 1) + c = contents.ProtocolVersion(2, 0) + d = contents.ProtocolVersion(0, 2) self.assertFalse(a >= b) self.assertTrue(b >= a) @@ -241,30 +325,25 @@ class TestProtocolVersion(TestCase): self.assertTrue(c >= d) self.assertFalse(d >= c) + # A direct call to __ge__ is required here due to differences in how + # Python 2 and Python 3 treat comparison operators. + self.assertEqual(NotImplemented, a.__ge__('invalid')) + def test_repr(self): - a = ProtocolVersion.create(1, 0) + """ + Test that repr can be applied to a ProtocolVersion struct. + """ + struct = contents.ProtocolVersion(1, 0) - self.assertEqual("1.0", "{0}".format(a)) + self.assertEqual( + "ProtocolVersion(major=1, minor=0)", + "{}".format(repr(struct)) + ) - def _test_create(self, major, minor): - protocol_version = ProtocolVersion.create(major, minor) + def test_str(self): + """ + Test that str can be applied to a ProtocolVersion struct. + """ + struct = contents.ProtocolVersion(1, 0) - if major is None: - expected = ProtocolVersion.ProtocolVersionMajor() - else: - expected = ProtocolVersion.ProtocolVersionMajor(major) - - self.assertEqual(expected, protocol_version.protocol_version_major) - - if minor is None: - expected = ProtocolVersion.ProtocolVersionMinor() - else: - expected = ProtocolVersion.ProtocolVersionMinor(minor) - - self.assertEqual(expected, protocol_version.protocol_version_minor) - - def test_create_with_none(self): - self._test_create(None, None) - - def test_create_with_args(self): - self._test_create(1, 0) + self.assertEqual("1.0", str(struct)) diff --git a/kmip/tests/unit/core/messages/payloads/test_discover_versions.py b/kmip/tests/unit/core/messages/payloads/test_discover_versions.py index 2bc5c66..4bc9189 100644 --- a/kmip/tests/unit/core/messages/payloads/test_discover_versions.py +++ b/kmip/tests/unit/core/messages/payloads/test_discover_versions.py @@ -30,10 +30,10 @@ class TestDiscoverVersionsRequestPayload(TestCase): self.protocol_versions_empty = list() self.protocol_versions_one = list() - self.protocol_versions_one.append(ProtocolVersion.create(1, 0)) + self.protocol_versions_one.append(ProtocolVersion(1, 0)) self.protocol_versions_two = list() - self.protocol_versions_two.append(ProtocolVersion.create(1, 1)) - self.protocol_versions_two.append(ProtocolVersion.create(1, 0)) + self.protocol_versions_two.append(ProtocolVersion(1, 1)) + self.protocol_versions_two.append(ProtocolVersion(1, 0)) self.encoding_empty = utils.BytearrayStream(( b'\x42\x00\x79\x01\x00\x00\x00\x00')) @@ -157,10 +157,10 @@ class TestDiscoverVersionsResponsePayload(TestCase): self.protocol_versions_empty = list() self.protocol_versions_one = list() - self.protocol_versions_one.append(ProtocolVersion.create(1, 0)) + self.protocol_versions_one.append(ProtocolVersion(1, 0)) self.protocol_versions_two = list() - self.protocol_versions_two.append(ProtocolVersion.create(1, 1)) - self.protocol_versions_two.append(ProtocolVersion.create(1, 0)) + self.protocol_versions_two.append(ProtocolVersion(1, 1)) + self.protocol_versions_two.append(ProtocolVersion(1, 0)) self.encoding_empty = utils.BytearrayStream(( b'\x42\x00\x7C\x01\x00\x00\x00\x00')) diff --git a/kmip/tests/unit/core/messages/test_messages.py b/kmip/tests/unit/core/messages/test_messages.py index 20a6633..7520590 100644 --- a/kmip/tests/unit/core/messages/test_messages.py +++ b/kmip/tests/unit/core/messages/test_messages.py @@ -183,25 +183,25 @@ class TestRequestMessage(TestCase): msg.format(contents.ProtocolVersion, type(protocol_version))) - protocol_version_major = protocol_version.protocol_version_major + protocol_version_major = protocol_version.major msg = "Bad protocol version major type: expected {0}, received {1}" - exp_type = contents.ProtocolVersion.ProtocolVersionMajor + exp_type = int rcv_type = type(protocol_version_major) self.assertIsInstance(protocol_version_major, exp_type, msg.format(exp_type, rcv_type)) msg = "Bad protocol version major value: expected {0}, received {1}" - self.assertEqual(1, protocol_version_major.value, - msg.format(1, protocol_version_major.value)) + self.assertEqual(1, protocol_version_major, + msg.format(1, protocol_version_major)) - protocol_version_minor = protocol_version.protocol_version_minor + protocol_version_minor = protocol_version.minor msg = "Bad protocol version minor type: expected {0}, received {1}" - exp_type = contents.ProtocolVersion.ProtocolVersionMinor + exp_type = int rcv_type = type(protocol_version_minor) self.assertIsInstance(protocol_version_minor, exp_type, msg.format(exp_type, rcv_type)) msg = "Bad protocol version minor value: expected {0}, received {1}" - self.assertEqual(1, protocol_version_minor.value, - msg.format(1, protocol_version_minor.value)) + self.assertEqual(1, protocol_version_minor, + msg.format(1, protocol_version_minor)) batch_count = request_header.batch_count msg = "Bad batch count type: expected {0}, received {1}" @@ -352,7 +352,7 @@ class TestRequestMessage(TestCase): exp_value, attribute_value.value)) def test_create_request_write(self): - prot_ver = contents.ProtocolVersion.create(1, 1) + prot_ver = contents.ProtocolVersion(1, 1) batch_count = contents.BatchCount(1) request_header = messages.RequestHeader(protocol_version=prot_ver, @@ -414,25 +414,25 @@ class TestRequestMessage(TestCase): msg.format(contents.ProtocolVersion, type(protocol_version))) - protocol_version_major = protocol_version.protocol_version_major + protocol_version_major = protocol_version.major msg = "Bad protocol version major type: expected {0}, received {1}" - exp_type = contents.ProtocolVersion.ProtocolVersionMajor + exp_type = int rcv_type = type(protocol_version_major) self.assertIsInstance(protocol_version_major, exp_type, msg.format(exp_type, rcv_type)) msg = "Bad protocol version major value: expected {0}, received {1}" - self.assertEqual(1, protocol_version_major.value, - msg.format(1, protocol_version_major.value)) + self.assertEqual(1, protocol_version_major, + msg.format(1, protocol_version_major)) - protocol_version_minor = protocol_version.protocol_version_minor + protocol_version_minor = protocol_version.minor msg = "Bad protocol version minor type: expected {0}, received {1}" - exp_type = contents.ProtocolVersion.ProtocolVersionMinor + exp_type = int rcv_type = type(protocol_version_minor) self.assertIsInstance(protocol_version_minor, exp_type, msg.format(exp_type, rcv_type)) msg = "Bad protocol version minor value: expected {0}, received {1}" - self.assertEqual(1, protocol_version_minor.value, - msg.format(1, protocol_version_minor.value)) + self.assertEqual(1, protocol_version_minor, + msg.format(1, protocol_version_minor)) batch_count = request_header.batch_count msg = "Bad batch count type: expected {0}, received {1}" @@ -486,7 +486,7 @@ class TestRequestMessage(TestCase): ) def test_get_request_write(self): - prot_ver = contents.ProtocolVersion.create(1, 1) + prot_ver = contents.ProtocolVersion(1, 1) batch_count = contents.BatchCount(1) req_header = messages.RequestHeader(protocol_version=prot_ver, @@ -532,25 +532,25 @@ class TestRequestMessage(TestCase): msg.format(contents.ProtocolVersion, type(protocol_version))) - protocol_version_major = protocol_version.protocol_version_major + protocol_version_major = protocol_version.major msg = "Bad protocol version major type: expected {0}, received {1}" - exp_type = contents.ProtocolVersion.ProtocolVersionMajor + exp_type = int rcv_type = type(protocol_version_major) self.assertIsInstance(protocol_version_major, exp_type, msg.format(exp_type, rcv_type)) msg = "Bad protocol version major value: expected {0}, received {1}" - self.assertEqual(1, protocol_version_major.value, - msg.format(1, protocol_version_major.value)) + self.assertEqual(1, protocol_version_major, + msg.format(1, protocol_version_major)) - protocol_version_minor = protocol_version.protocol_version_minor + protocol_version_minor = protocol_version.minor msg = "Bad protocol version minor type: expected {0}, received {1}" - exp_type = contents.ProtocolVersion.ProtocolVersionMinor + exp_type = int rcv_type = type(protocol_version_minor) self.assertIsInstance(protocol_version_minor, exp_type, msg.format(exp_type, rcv_type)) msg = "Bad protocol version minor value: expected {0}, received {1}" - self.assertEqual(1, protocol_version_minor.value, - msg.format(1, protocol_version_minor.value)) + self.assertEqual(1, protocol_version_minor, + msg.format(1, protocol_version_minor)) batch_count = request_header.batch_count msg = "Bad batch count type: expected {0}, received {1}" @@ -605,7 +605,7 @@ class TestRequestMessage(TestCase): msg.format(exp_value, rcv_value)) def test_destroy_request_write(self): - prot_ver = contents.ProtocolVersion.create(1, 1) + prot_ver = contents.ProtocolVersion(1, 1) batch_count = contents.BatchCount(1) req_header = messages.RequestHeader(protocol_version=prot_ver, @@ -651,25 +651,25 @@ class TestRequestMessage(TestCase): msg.format(contents.ProtocolVersion, type(protocol_version))) - protocol_version_major = protocol_version.protocol_version_major + protocol_version_major = protocol_version.major msg = "Bad protocol version major type: expected {0}, received {1}" - exp_type = contents.ProtocolVersion.ProtocolVersionMajor + exp_type = int rcv_type = type(protocol_version_major) self.assertIsInstance(protocol_version_major, exp_type, msg.format(exp_type, rcv_type)) msg = "Bad protocol version major value: expected {0}, received {1}" - self.assertEqual(1, protocol_version_major.value, - msg.format(1, protocol_version_major.value)) + self.assertEqual(1, protocol_version_major, + msg.format(1, protocol_version_major)) - protocol_version_minor = protocol_version.protocol_version_minor + protocol_version_minor = protocol_version.minor msg = "Bad protocol version minor type: expected {0}, received {1}" - exp_type = contents.ProtocolVersion.ProtocolVersionMinor + exp_type = int rcv_type = type(protocol_version_minor) self.assertIsInstance(protocol_version_minor, exp_type, msg.format(exp_type, rcv_type)) msg = "Bad protocol version minor value: expected {0}, received {1}" - self.assertEqual(1, protocol_version_minor.value, - msg.format(1, protocol_version_minor.value)) + self.assertEqual(1, protocol_version_minor, + msg.format(1, protocol_version_minor)) batch_count = request_header.batch_count msg = "Bad batch count type: expected {0}, received {1}" @@ -775,7 +775,7 @@ class TestRequestMessage(TestCase): msg.format(exp_length, rcv_length)) def test_register_request_write(self): - prot_ver = contents.ProtocolVersion.create(1, 1) + prot_ver = contents.ProtocolVersion(1, 1) batch_count = contents.BatchCount(1) req_header = messages.RequestHeader(protocol_version=prot_ver, @@ -867,25 +867,25 @@ class TestRequestMessage(TestCase): msg.format(contents.ProtocolVersion, type(protocol_version))) - protocol_version_major = protocol_version.protocol_version_major + protocol_version_major = protocol_version.major msg = "Bad protocol version major type: expected {0}, received {1}" - exp_type = contents.ProtocolVersion.ProtocolVersionMajor + exp_type = int rcv_type = type(protocol_version_major) self.assertIsInstance(protocol_version_major, exp_type, msg.format(exp_type, rcv_type)) msg = "Bad protocol version major value: expected {0}, received {1}" - self.assertEqual(1, protocol_version_major.value, - msg.format(1, protocol_version_major.value)) + self.assertEqual(1, protocol_version_major, + msg.format(1, protocol_version_major)) - protocol_version_minor = protocol_version.protocol_version_minor + protocol_version_minor = protocol_version.minor msg = "Bad protocol version minor type: expected {0}, received {1}" - exp_type = contents.ProtocolVersion.ProtocolVersionMinor + exp_type = int rcv_type = type(protocol_version_minor) self.assertIsInstance(protocol_version_minor, exp_type, msg.format(exp_type, rcv_type)) msg = "Bad protocol version minor value: expected {0}, received {1}" - self.assertEqual(1, protocol_version_minor.value, - msg.format(1, protocol_version_minor.value)) + self.assertEqual(1, protocol_version_minor, + msg.format(1, protocol_version_minor)) batch_count = request_header.batch_count msg = "Bad batch count type: expected {0}, received {1}" @@ -1012,25 +1012,25 @@ class TestRequestMessage(TestCase): msg.format(contents.ProtocolVersion, type(protocol_version))) - protocol_version_major = protocol_version.protocol_version_major + protocol_version_major = protocol_version.major msg = "Bad protocol version major type: expected {0}, received {1}" - exp_type = contents.ProtocolVersion.ProtocolVersionMajor + exp_type = int rcv_type = type(protocol_version_major) self.assertIsInstance(protocol_version_major, exp_type, msg.format(exp_type, rcv_type)) msg = "Bad protocol version major value: expected {0}, received {1}" - self.assertEqual(1, protocol_version_major.value, - msg.format(1, protocol_version_major.value)) + self.assertEqual(1, protocol_version_major, + msg.format(1, protocol_version_major)) - protocol_version_minor = protocol_version.protocol_version_minor + protocol_version_minor = protocol_version.minor msg = "Bad protocol version minor type: expected {0}, received {1}" - exp_type = contents.ProtocolVersion.ProtocolVersionMinor + exp_type = int rcv_type = type(protocol_version_minor) self.assertIsInstance(protocol_version_minor, exp_type, msg.format(exp_type, rcv_type)) msg = "Bad protocol version minor value: expected {0}, received {1}" - self.assertEqual(2, protocol_version_minor.value, - msg.format(2, protocol_version_minor.value)) + self.assertEqual(2, protocol_version_minor, + msg.format(2, protocol_version_minor)) batch_count = request_header.batch_count msg = "Bad batch count type: expected {0}, received {1}" @@ -1114,7 +1114,7 @@ class TestRequestMessage(TestCase): ) def test_mac_request_write(self): - prot_ver = contents.ProtocolVersion.create(1, 2) + prot_ver = contents.ProtocolVersion(1, 2) batch_count = contents.BatchCount(1) req_header = messages.RequestHeader(protocol_version=prot_ver, @@ -1284,26 +1284,26 @@ class TestResponseMessage(TestCase): contents.ProtocolVersion, type(protocol_version))) - protocol_version_major = protocol_version.protocol_version_major - exp_type = contents.ProtocolVersion.ProtocolVersionMajor + protocol_version_major = protocol_version.major + exp_type = int rcv_type = type(protocol_version_major) self.assertIsInstance(protocol_version_major, exp_type, self.msg.format('protocol version major', 'type', exp_type, rcv_type)) - self.assertEqual(1, protocol_version_major.value, + self.assertEqual(1, protocol_version_major, self.msg.format('protocol version major', 'value', - 1, protocol_version_major.value)) + 1, protocol_version_major)) - protocol_version_minor = protocol_version.protocol_version_minor - exp_type = contents.ProtocolVersion.ProtocolVersionMinor + protocol_version_minor = protocol_version.minor + exp_type = int rcv_type = type(protocol_version_minor) self.assertIsInstance(protocol_version_minor, - contents.ProtocolVersion.ProtocolVersionMinor, + int, self.msg.format('protocol version minor', 'type', exp_type, rcv_type)) - self.assertEqual(1, protocol_version_minor.value, + self.assertEqual(1, protocol_version_minor, self.msg.format('protocol version minor', 'value', - 1, protocol_version_minor.value)) + 1, protocol_version_minor)) time_stamp = response_header.time_stamp value = 0x4f9a54e5 # Fri Apr 27 10:12:21 CEST 2012 @@ -1383,7 +1383,7 @@ class TestResponseMessage(TestCase): unique_identifier.value, value)) def test_create_response_write(self): - prot_ver = contents.ProtocolVersion.create(1, 1) + prot_ver = contents.ProtocolVersion(1, 1) # Fri Apr 27 10:12:21 CEST 2012 time_stamp = contents.TimeStamp(0x4f9a54e5) @@ -1436,25 +1436,25 @@ class TestResponseMessage(TestCase): contents.ProtocolVersion, type(protocol_version))) - protocol_version_major = protocol_version.protocol_version_major - exp_type = contents.ProtocolVersion.ProtocolVersionMajor + protocol_version_major = protocol_version.major + exp_type = int rcv_type = type(protocol_version_major) self.assertIsInstance(protocol_version_major, exp_type, self.msg.format('protocol version major', 'type', exp_type, rcv_type)) - self.assertEqual(1, protocol_version_major.value, + self.assertEqual(1, protocol_version_major, self.msg.format('protocol version major', 'value', - 1, protocol_version_major.value)) + 1, protocol_version_major)) - protocol_version_minor = protocol_version.protocol_version_minor - exp_type = contents.ProtocolVersion.ProtocolVersionMinor + protocol_version_minor = protocol_version.minor + exp_type = int rcv_type = type(protocol_version_minor) self.assertIsInstance(protocol_version_minor, exp_type, self.msg.format('protocol version minor', 'type', exp_type, rcv_type)) - self.assertEqual(1, protocol_version_minor.value, + self.assertEqual(1, protocol_version_minor, self.msg.format('protocol version minor', 'value', - 1, protocol_version_minor.value)) + 1, protocol_version_minor)) time_stamp = response_header.time_stamp value = 0x4f9a54e7 # Fri Apr 27 10:12:23 CEST 2012 @@ -1584,7 +1584,7 @@ class TestResponseMessage(TestCase): 'value', exp, obs)) def test_get_response_write(self): - prot_ver = contents.ProtocolVersion.create(1, 1) + prot_ver = contents.ProtocolVersion(1, 1) # Fri Apr 27 10:12:23 CEST 2012 time_stamp = contents.TimeStamp(0x4f9a54e7) @@ -1665,25 +1665,25 @@ class TestResponseMessage(TestCase): msg.format(contents.ProtocolVersion, type(protocol_version))) - protocol_version_major = protocol_version.protocol_version_major + protocol_version_major = protocol_version.major msg = "Bad protocol version major type: expected {0}, received {1}" - exp_type = contents.ProtocolVersion.ProtocolVersionMajor + exp_type = int rcv_type = type(protocol_version_major) self.assertIsInstance(protocol_version_major, exp_type, msg.format(exp_type, rcv_type)) msg = "Bad protocol version major value: expected {0}, received {1}" - self.assertEqual(1, protocol_version_major.value, - msg.format(1, protocol_version_major.value)) + self.assertEqual(1, protocol_version_major, + msg.format(1, protocol_version_major)) - protocol_version_minor = protocol_version.protocol_version_minor + protocol_version_minor = protocol_version.minor msg = "Bad protocol version minor type: expected {0}, received {1}" - exp_type = contents.ProtocolVersion.ProtocolVersionMinor + exp_type = int rcv_type = type(protocol_version_minor) self.assertIsInstance(protocol_version_minor, exp_type, msg.format(exp_type, rcv_type)) msg = "Bad protocol version minor value: expected {0}, received {1}" - self.assertEqual(1, protocol_version_minor.value, - msg.format(1, protocol_version_minor.value)) + self.assertEqual(1, protocol_version_minor, + msg.format(1, protocol_version_minor)) time_stamp = response_header.time_stamp value = 0x4f9a54e5 # Fri Apr 27 10:12:21 CEST 2012 @@ -1758,7 +1758,7 @@ class TestResponseMessage(TestCase): msg.format(exp_value, rcv_value)) def test_destroy_response_write(self): - prot_ver = contents.ProtocolVersion.create(1, 1) + prot_ver = contents.ProtocolVersion(1, 1) # Fri Apr 27 10:12:21 CEST 2012 time_stamp = contents.TimeStamp(0x4f9a54e5) @@ -1808,25 +1808,25 @@ class TestResponseMessage(TestCase): msg.format(contents.ProtocolVersion, type(protocol_version))) - protocol_version_major = protocol_version.protocol_version_major + protocol_version_major = protocol_version.major msg = "Bad protocol version major type: expected {0}, received {1}" - exp_type = contents.ProtocolVersion.ProtocolVersionMajor + exp_type = int rcv_type = type(protocol_version_major) self.assertIsInstance(protocol_version_major, exp_type, msg.format(exp_type, rcv_type)) msg = "Bad protocol version major value: expected {0}, received {1}" - self.assertEqual(1, protocol_version_major.value, - msg.format(1, protocol_version_major.value)) + self.assertEqual(1, protocol_version_major, + msg.format(1, protocol_version_major)) - protocol_version_minor = protocol_version.protocol_version_minor + protocol_version_minor = protocol_version.minor msg = "Bad protocol version minor type: expected {0}, received {1}" - exp_type = contents.ProtocolVersion.ProtocolVersionMinor + exp_type = int rcv_type = type(protocol_version_minor) self.assertIsInstance(protocol_version_minor, exp_type, msg.format(exp_type, rcv_type)) msg = "Bad protocol version minor value: expected {0}, received {1}" - self.assertEqual(1, protocol_version_minor.value, - msg.format(1, protocol_version_minor.value)) + self.assertEqual(1, protocol_version_minor, + msg.format(1, protocol_version_minor)) time_stamp = response_header.time_stamp value = 0x4f9a54e5 # Fri Apr 27 10:12:21 CEST 2012 @@ -1901,7 +1901,7 @@ class TestResponseMessage(TestCase): msg.format(exp_value, rcv_value)) def test_register_response_write(self): - prot_ver = contents.ProtocolVersion.create(1, 1) + prot_ver = contents.ProtocolVersion(1, 1) # Fri Apr 27 10:12:21 CEST 2012 time_stamp = contents.TimeStamp(0x4f9a54e5) @@ -1934,7 +1934,7 @@ class TestResponseMessage(TestCase): self.assertEqual(self.register, result, msg) def test_locate_response_write(self): - prot_ver = contents.ProtocolVersion.create(1, 1) + prot_ver = contents.ProtocolVersion(1, 1) # Fri Apr 27 10:12:22 CEST 2012 time_stamp = contents.TimeStamp(0x4f9a54e6) @@ -1985,25 +1985,25 @@ class TestResponseMessage(TestCase): contents.ProtocolVersion, type(protocol_version))) - protocol_version_major = protocol_version.protocol_version_major - exp_type = contents.ProtocolVersion.ProtocolVersionMajor + protocol_version_major = protocol_version.major + exp_type = int rcv_type = type(protocol_version_major) self.assertIsInstance(protocol_version_major, exp_type, self.msg.format('protocol version major', 'type', exp_type, rcv_type)) - self.assertEqual(1, protocol_version_major.value, + self.assertEqual(1, protocol_version_major, self.msg.format('protocol version major', 'value', - 1, protocol_version_major.value)) + 1, protocol_version_major)) - protocol_version_minor = protocol_version.protocol_version_minor - exp_type = contents.ProtocolVersion.ProtocolVersionMinor + protocol_version_minor = protocol_version.minor + exp_type = int rcv_type = type(protocol_version_minor) self.assertIsInstance(protocol_version_minor, exp_type, self.msg.format('protocol version minor', 'type', exp_type, rcv_type)) - self.assertEqual(2, protocol_version_minor.value, + self.assertEqual(2, protocol_version_minor, self.msg.format('protocol version minor', 'value', - 2, protocol_version_minor.value)) + 2, protocol_version_minor)) time_stamp = response_header.time_stamp value = 0x588a3f23 @@ -2092,7 +2092,7 @@ class TestResponseMessage(TestCase): binascii.hexlify(value))) def test_mac_response_write(self): - prot_ver = contents.ProtocolVersion.create(1, 2) + prot_ver = contents.ProtocolVersion(1, 2) # Fri Apr 27 10:12:23 CEST 2012 time_stamp = contents.TimeStamp(0x588a3f23) @@ -2142,7 +2142,7 @@ class TestResponseMessage(TestCase): def test_message_invalid_response_write(self): # Batch item of 'INVALID MESSAGE' response # has no 'operation' attribute - prot_ver = contents.ProtocolVersion.create(1, 1) + prot_ver = contents.ProtocolVersion(1, 1) # Time stamp Tue Mar 29 10:58:37 2016 time_stamp = contents.TimeStamp(0x56fa43bd) diff --git a/kmip/tests/unit/services/server/test_engine.py b/kmip/tests/unit/services/server/test_engine.py index c31bdcd..ded97a8 100644 --- a/kmip/tests/unit/services/server/test_engine.py +++ b/kmip/tests/unit/services/server/test_engine.py @@ -89,7 +89,7 @@ class TestKmipEngine(testtools.TestCase): ) ] - protocol = contents.ProtocolVersion.create(1, 0) + protocol = contents.ProtocolVersion(1, 0) max_size = contents.MaximumResponseSize(2 ** 20) asynch = contents.AsynchronousIndicator(False) @@ -339,7 +339,7 @@ class TestKmipEngine(testtools.TestCase): """ e = engine.KmipEngine() e._logger = mock.MagicMock() - e._protocol_version = contents.ProtocolVersion.create(1, 0) + e._protocol_version = contents.ProtocolVersion(1, 0) args = (None, ) regex = "DiscoverVersions is not supported by KMIP {0}".format( @@ -360,7 +360,7 @@ class TestKmipEngine(testtools.TestCase): e = engine.KmipEngine() e._logger = mock.MagicMock() - protocol = contents.ProtocolVersion.create(1, 1) + protocol = contents.ProtocolVersion(1, 1) header = messages.RequestHeader( protocol_version=protocol, maximum_response_size=contents.MaximumResponseSize(2 ** 20), @@ -402,7 +402,7 @@ class TestKmipEngine(testtools.TestCase): self.assertIsNotNone(header) self.assertEqual( - contents.ProtocolVersion.create(1, 1), + contents.ProtocolVersion(1, 1), header.protocol_version ) self.assertIsInstance(header.time_stamp, contents.TimeStamp) @@ -442,7 +442,7 @@ class TestKmipEngine(testtools.TestCase): e = engine.KmipEngine() e._logger = mock.MagicMock() - protocol = contents.ProtocolVersion.create(0, 1) + protocol = contents.ProtocolVersion(0, 1) header = messages.RequestHeader( protocol_version=protocol ) @@ -470,7 +470,7 @@ class TestKmipEngine(testtools.TestCase): e = engine.KmipEngine() e._logger = mock.MagicMock() - protocol = contents.ProtocolVersion.create(1, 0) + protocol = contents.ProtocolVersion(1, 0) header = messages.RequestHeader( protocol_version=protocol, time_stamp=contents.TimeStamp(0) @@ -503,7 +503,7 @@ class TestKmipEngine(testtools.TestCase): e = engine.KmipEngine() e._logger = mock.MagicMock() - protocol = contents.ProtocolVersion.create(1, 0) + protocol = contents.ProtocolVersion(1, 0) header = messages.RequestHeader( protocol_version=protocol, time_stamp=contents.TimeStamp(10 ** 10) @@ -536,7 +536,7 @@ class TestKmipEngine(testtools.TestCase): e = engine.KmipEngine() e._logger = mock.MagicMock() - protocol = contents.ProtocolVersion.create(1, 1) + protocol = contents.ProtocolVersion(1, 1) header = messages.RequestHeader( protocol_version=protocol, asynchronous_indicator=contents.AsynchronousIndicator(True) @@ -563,7 +563,7 @@ class TestKmipEngine(testtools.TestCase): e = engine.KmipEngine() e._logger = mock.MagicMock() - protocol = contents.ProtocolVersion.create(1, 1) + protocol = contents.ProtocolVersion(1, 1) header = messages.RequestHeader( protocol_version=protocol, authentication=contents.Authentication(), @@ -593,7 +593,7 @@ class TestKmipEngine(testtools.TestCase): e = engine.KmipEngine() e._logger = mock.MagicMock() - protocol = contents.ProtocolVersion.create(1, 1) + protocol = contents.ProtocolVersion(1, 1) header = messages.RequestHeader( protocol_version=protocol, authentication=None, @@ -629,7 +629,7 @@ class TestKmipEngine(testtools.TestCase): e._logger = mock.MagicMock() response = e.build_error_response( - contents.ProtocolVersion.create(1, 1), + contents.ProtocolVersion(1, 1), enums.ResultReason.GENERAL_FAILURE, "A general test failure occurred." ) @@ -639,7 +639,7 @@ class TestKmipEngine(testtools.TestCase): header = response.response_header self.assertEqual( - contents.ProtocolVersion.create(1, 1), + contents.ProtocolVersion(1, 1), header.protocol_version ) self.assertIsNotNone(header.time_stamp) @@ -760,7 +760,7 @@ class TestKmipEngine(testtools.TestCase): """ e = engine.KmipEngine() e._logger = mock.MagicMock() - e._protocol_version = contents.ProtocolVersion.create(1, 0) + e._protocol_version = contents.ProtocolVersion(1, 0) batch = list([ messages.RequestBatchItem( @@ -6515,7 +6515,7 @@ class TestKmipEngine(testtools.TestCase): e = engine.KmipEngine() e._logger = mock.MagicMock() - e._protocol_version = contents.ProtocolVersion.create(1, 0) + e._protocol_version = contents.ProtocolVersion(1, 0) payload = payloads.QueryRequestPayload([ misc.QueryFunction(enums.QueryFunction.QUERY_OPERATIONS), @@ -6601,7 +6601,7 @@ class TestKmipEngine(testtools.TestCase): e = engine.KmipEngine() e._logger = mock.MagicMock() - e._protocol_version = contents.ProtocolVersion.create(1, 1) + e._protocol_version = contents.ProtocolVersion(1, 1) payload = payloads.QueryRequestPayload([ misc.QueryFunction(enums.QueryFunction.QUERY_OPERATIONS), @@ -6691,7 +6691,7 @@ class TestKmipEngine(testtools.TestCase): e = engine.KmipEngine() e._logger = mock.MagicMock() - e._protocol_version = contents.ProtocolVersion.create(1, 2) + e._protocol_version = contents.ProtocolVersion(1, 2) payload = payloads.QueryRequestPayload([ misc.QueryFunction(enums.QueryFunction.QUERY_OPERATIONS), @@ -6817,22 +6817,22 @@ class TestKmipEngine(testtools.TestCase): self.assertIsNotNone(result.protocol_versions) self.assertEqual(3, len(result.protocol_versions)) self.assertEqual( - contents.ProtocolVersion.create(1, 2), + contents.ProtocolVersion(1, 2), result.protocol_versions[0] ) self.assertEqual( - contents.ProtocolVersion.create(1, 1), + contents.ProtocolVersion(1, 1), result.protocol_versions[1] ) self.assertEqual( - contents.ProtocolVersion.create(1, 0), + contents.ProtocolVersion(1, 0), result.protocol_versions[2] ) # Test detailed request. e._logger = mock.MagicMock() payload = payloads.DiscoverVersionsRequestPayload([ - contents.ProtocolVersion.create(1, 0) + contents.ProtocolVersion(1, 0) ]) result = e._process_discover_versions(payload) @@ -6843,14 +6843,14 @@ class TestKmipEngine(testtools.TestCase): self.assertIsNotNone(result.protocol_versions) self.assertEqual(1, len(result.protocol_versions)) self.assertEqual( - contents.ProtocolVersion.create(1, 0), + contents.ProtocolVersion(1, 0), result.protocol_versions[0] ) # Test disjoint request. e._logger = mock.MagicMock() payload = payloads.DiscoverVersionsRequestPayload([ - contents.ProtocolVersion.create(0, 1) + contents.ProtocolVersion(0, 1) ]) result = e._process_discover_versions(payload) diff --git a/kmip/tests/unit/services/server/test_policy.py b/kmip/tests/unit/services/server/test_policy.py index 42b902e..90af6ed 100644 --- a/kmip/tests/unit/services/server/test_policy.py +++ b/kmip/tests/unit/services/server/test_policy.py @@ -35,14 +35,14 @@ class TestAttributePolicy(testtools.TestCase): """ Test that an AttributePolicy can be built without any errors. """ - policy.AttributePolicy(contents.ProtocolVersion.create(1, 0)) + policy.AttributePolicy(contents.ProtocolVersion(1, 0)) def test_is_attribute_supported(self): """ Test that is_attribute_supported returns the expected results in all cases. """ - rules = policy.AttributePolicy(contents.ProtocolVersion.create(1, 0)) + rules = policy.AttributePolicy(contents.ProtocolVersion(1, 0)) attribute_a = 'Unique Identifier' attribute_b = 'Certificate Length' attribute_c = 'invalid' @@ -61,7 +61,7 @@ class TestAttributePolicy(testtools.TestCase): Test that is_attribute_deprecated returns the expected results in all cases. """ - rules = policy.AttributePolicy(contents.ProtocolVersion.create(1, 0)) + rules = policy.AttributePolicy(contents.ProtocolVersion(1, 0)) attribute_a = 'Name' attribute_b = 'Certificate Subject' @@ -71,7 +71,7 @@ class TestAttributePolicy(testtools.TestCase): result = rules.is_attribute_deprecated(attribute_b) self.assertFalse(result) - rules = policy.AttributePolicy(contents.ProtocolVersion.create(1, 1)) + rules = policy.AttributePolicy(contents.ProtocolVersion(1, 1)) result = rules.is_attribute_deprecated(attribute_b) self.assertTrue(result) @@ -81,7 +81,7 @@ class TestAttributePolicy(testtools.TestCase): Test that is_attribute_applicable_to_object_type returns the expected results in all cases. """ - rules = policy.AttributePolicy(contents.ProtocolVersion.create(1, 0)) + rules = policy.AttributePolicy(contents.ProtocolVersion(1, 0)) attribute = 'Cryptographic Algorithm' object_type_a = enums.ObjectType.SYMMETRIC_KEY object_type_b = enums.ObjectType.OPAQUE_DATA @@ -103,7 +103,7 @@ class TestAttributePolicy(testtools.TestCase): Test that is_attribute_multivalued returns the expected results in all cases. """ - rules = policy.AttributePolicy(contents.ProtocolVersion.create(1, 0)) + rules = policy.AttributePolicy(contents.ProtocolVersion(1, 0)) attribute_a = 'Object Type' attribute_b = 'Link' @@ -118,7 +118,7 @@ class TestAttributePolicy(testtools.TestCase): Test that get_all_attribute_names returns a complete list of the names of all spec-defined attributes. """ - rules = policy.AttributePolicy(contents.ProtocolVersion.create(1, 0)) + rules = policy.AttributePolicy(contents.ProtocolVersion(1, 0)) attribute_names = [ 'Unique Identifier', 'Name', diff --git a/kmip/tests/unit/services/server/test_session.py b/kmip/tests/unit/services/server/test_session.py index 794dad4..3ebc057 100644 --- a/kmip/tests/unit/services/server/test_session.py +++ b/kmip/tests/unit/services/server/test_session.py @@ -345,7 +345,7 @@ class TestKmipSession(testtools.TestCase): ) batch_items = [batch_item] header = messages.ResponseHeader( - protocol_version=contents.ProtocolVersion.create(1, 0), + protocol_version=contents.ProtocolVersion(1, 0), time_stamp=contents.TimeStamp(int(time.time())), batch_count=contents.BatchCount(len(batch_items)) ) diff --git a/kmip/tests/unit/services/test_kmip_client.py b/kmip/tests/unit/services/test_kmip_client.py index e6b94a9..36cc5e1 100644 --- a/kmip/tests/unit/services/test_kmip_client.py +++ b/kmip/tests/unit/services/test_kmip_client.py @@ -385,7 +385,7 @@ class TestKMIPClient(TestCase): self.assertEqual(protocol_versions, observed, msg) def test_build_discover_versions_batch_item_with_input(self): - protocol_versions = [ProtocolVersion.create(1, 0)] + protocol_versions = [ProtocolVersion(1, 0)] self._test_build_discover_versions_batch_item(protocol_versions) def test_build_discover_versions_batch_item_no_input(self): @@ -612,7 +612,7 @@ class TestKMIPClient(TestCase): self.assertEqual(protocol_versions, result.protocol_versions, msg) def test_process_discover_versions_batch_item_with_results(self): - protocol_versions = [ProtocolVersion.create(1, 0)] + protocol_versions = [ProtocolVersion(1, 0)] self._test_process_discover_versions_batch_item(protocol_versions) def test_process_discover_versions_batch_item_no_results(self):