diff --git a/kmip/core/exceptions.py b/kmip/core/exceptions.py index c238171..fcab164 100644 --- a/kmip/core/exceptions.py +++ b/kmip/core/exceptions.py @@ -19,3 +19,11 @@ class InvalidKmipEncoding(Exception): An exception raised when processing invalid KMIP message encodings. """ pass + + +class InvalidPrimitiveLength(Exception): + """ + An exception raised for errors when processing primitives with invalid + lengths. + """ + pass diff --git a/kmip/core/primitives.py b/kmip/core/primitives.py index adb4d3a..c8003e3 100644 --- a/kmip/core/primitives.py +++ b/kmip/core/primitives.py @@ -26,6 +26,7 @@ from kmip.core.enums import Tags from kmip.core.errors import ErrorStrings from kmip.core import errors +from kmip.core import exceptions from kmip.core import utils @@ -249,51 +250,110 @@ class Integer(Base): class LongInteger(Base): + """ + An encodeable object representing a long integer value. + + A LongInteger is one of the KMIP primitive object types. It is encoded as + a signed, big-endian, 64-bit integer. For more information, see Section + 9.1 of the KMIP 1.1 specification. + """ + LENGTH = 8 - def __init__(self, value=None, tag=Tags.DEFAULT): + # Bounds for signed 64-bit integers + MIN = -9223372036854775808 + MAX = 9223372036854775807 + + def __init__(self, value=0, tag=Tags.DEFAULT): + """ + Create a LongInteger. + + Args: + value (int): The value of the LongInteger. Optional, defaults to 0. + tag (Tags): An enumeration defining the tag of the LongInteger. + Optional, defaults to Tags.DEFAULT. + """ super(LongInteger, self).__init__(tag, type=Types.LONG_INTEGER) self.value = value - self.length = self.LENGTH + self.length = LongInteger.LENGTH self.validate() - def read_value(self, istream): - if self.length is not self.LENGTH: - raise errors.ReadValueError(LongInteger.__name__, 'length', - self.LENGTH, self.length) + def read(self, istream): + """ + Read the encoding of the LongInteger from the input stream. + + Args: + istream (stream): A buffer containing the encoded bytes of a + LongInteger. Usually a BytearrayStream object. Required. + + Raises: + InvalidPrimitiveLength: if the long integer encoding read in has + an invalid encoded length. + """ + super(LongInteger, self).read(istream) + + if self.length is not LongInteger.LENGTH: + raise exceptions.InvalidPrimitiveLength( + "invalid long integer length read; " + "expected: {0}, observed: {1}".format( + LongInteger.LENGTH, self.length)) self.value = unpack('!q', istream.read(self.length))[0] self.validate() - def read(self, istream): - super(LongInteger, self).read(istream) - self.read_value(istream) + def write(self, ostream): + """ + Write the encoding of the LongInteger to the output stream. - def write_value(self, ostream): + Args: + ostream (stream): A buffer to contain the encoded bytes of a + LongInteger. Usually a BytearrayStream object. Required. + """ + super(LongInteger, self).write(ostream) ostream.write(pack('!q', self.value)) - def write(self, ostream): - super(LongInteger, self).write(ostream) - self.write_value(ostream) - def validate(self): - self.__validate() + """ + Verify that the value of the LongInteger is valid. - def __validate(self): + Raises: + TypeError: if the value is not of type int or long + ValueError: if the value cannot be represented by a signed 64-bit + integer + """ if self.value is not None: - data_type = type(self.value) - if data_type not in six.integer_types: - raise errors.StateTypeError( - LongInteger.__name__, "{0}".format(six.integer_types), - data_type) - num_bytes = utils.count_bytes(self.value) - if num_bytes > self.length: - raise errors.StateOverflowError( - LongInteger.__name__, 'value', self.length, num_bytes) + if not isinstance(self.value, six.integer_types): + raise TypeError('expected (one of): {0}, observed: {1}'.format( + six.integer_types, type(self.value))) + else: + if self.value > LongInteger.MAX: + raise ValueError( + 'long integer value greater than accepted max') + elif self.value < LongInteger.MIN: + raise ValueError( + 'long integer value less than accepted min') def __repr__(self): - return '' % (self.value) + return "LongInteger(value={0}, tag={1})".format(self.value, self.tag) + + def __str__(self): + return str(self.value) + + def __eq__(self, other): + if isinstance(other, LongInteger): + if self.value == other.value: + return True + else: + return False + else: + return NotImplemented + + def __ne__(self, other): + if isinstance(other, LongInteger): + return not self.__eq__(other) + else: + return NotImplemented class BigInteger(Base): diff --git a/kmip/tests/unit/core/primitives/test_long_integer.py b/kmip/tests/unit/core/primitives/test_long_integer.py index 420d7bf..95e8785 100644 --- a/kmip/tests/unit/core/primitives/test_long_integer.py +++ b/kmip/tests/unit/core/primitives/test_long_integer.py @@ -15,195 +15,344 @@ import testtools -from kmip.core import errors +from kmip.core import exceptions from kmip.core import primitives from kmip.core import utils class TestLongInteger(testtools.TestCase): + """ + Test suite for the LongInteger primitive. + """ def setUp(self): super(TestLongInteger, self).setUp() - self.stream = utils.BytearrayStream() - self.max_byte_long = 18446744073709551615 - self.max_long = 9223372036854775807 - self.bad_value = ( - 'Bad primitives.LongInteger.{0} after init: expected {1}, ' - 'received {2}') - self.bad_write = ( - 'Bad primitives.LongInteger write: expected {0} bytes, ' - 'received {1} bytes') - self.bad_encoding = ( - 'Bad primitives.LongInteger write: encoding mismatch') - self.bad_read = ( - 'Bad primitives.LongInteger.value read: expected {0}, ' - 'received {1}') def tearDown(self): super(TestLongInteger, self).tearDown() def test_init(self): - i = primitives.LongInteger(0) - - self.assertEqual(0, i.value, - self.bad_value.format('value', 0, i.value)) - self.assertEqual(i.LENGTH, i.length, - self.bad_value.format('length', i.LENGTH, i.length)) + """ + Test that a LongInteger can be instantiated. + """ + long_int = primitives.LongInteger(1) + self.assertEqual(1, long_int.value) def test_init_unset(self): - i = primitives.LongInteger() + """ + Test that a LongInteger can be instantiated with no input. + """ + long_int = primitives.LongInteger() + self.assertEqual(0, long_int.value) - self.assertEqual(None, i.value, - self.bad_value.format('value', None, i.value)) - self.assertEqual(i.LENGTH, i.length, - self.bad_value.format('length', i.LENGTH, i.length)) + def test_init_on_max(self): + """ + Test that a LongInteger can be instantiated with the maximum possible + signed 64-bit value. + """ + primitives.LongInteger(primitives.LongInteger.MAX) + + def test_init_on_min(self): + """ + Test that a LongInteger can be instantiated with the minimum possible + signed 64-bit value. + """ + primitives.LongInteger(primitives.LongInteger.MIN) def test_validate_on_valid(self): - i = primitives.LongInteger() - i.value = 0 - - # Check no exception thrown - i.validate() - - def test_validate_on_valid_long(self): - i = primitives.LongInteger() - i.value = self.max_long + 1 - - # Check no exception thrown - i.validate() + """ + Test that a LongInteger can be validated on good input. + """ + long_int = primitives.LongInteger(1) + long_int.validate() def test_validate_on_valid_unset(self): - i = primitives.LongInteger() - - # Check no exception thrown - i.validate() + """ + Test that a LongInteger with no preset value can be validated. + """ + long_int = primitives.LongInteger() + long_int.validate() def test_validate_on_invalid_type(self): - i = primitives.LongInteger() - i.value = 'test' + """ + Test that a TypeError is thrown on input of invalid type (e.g., str). + """ + self.assertRaises(TypeError, primitives.LongInteger, 'invalid') - self.assertRaises(errors.StateTypeError, i.validate) + def test_validate_on_invalid_value_too_big(self): + """ + Test that a ValueError is thrown on input that is too large. + """ + self.assertRaises( + ValueError, primitives.LongInteger, primitives.LongInteger.MAX + 1) - def test_validate_on_invalid_value(self): - self.assertRaises(errors.StateOverflowError, primitives.LongInteger, - self.max_byte_long + 1) + def test_validate_on_invalid_value_too_small(self): + """ + Test that a ValueError is thrown on input that is too small. + """ + self.assertRaises( + ValueError, primitives.LongInteger, primitives.LongInteger.MIN - 1) - def test_read_value(self): - encoding = (b'\x00\x00\x00\x00\x00\x00\x00\x01') - self.stream = utils.BytearrayStream(encoding) - i = primitives.LongInteger() - i.read_value(self.stream) + def test_read_zero(self): + """ + Test that a LongInteger representing the value 0 can be read from a + byte stream. + """ + encoding = ( + b'\x42\x00\x00\x03\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00' + b'\x00') + stream = utils.BytearrayStream(encoding) + long_int = primitives.LongInteger() + long_int.read(stream) + self.assertEqual(0, long_int.value) - self.assertEqual(1, i.value, self.bad_read.format(1, i.value)) + def test_read_max_max(self): + """ + Test that a LongInteger representing the maximum positive value can be + read from a byte stream. + """ + encoding = ( + b'\x42\x00\x00\x03\x00\x00\x00\x08\x7f\xff\xff\xff\xff\xff\xff' + b'\xff') + stream = utils.BytearrayStream(encoding) + long_int = primitives.LongInteger() + long_int.read(stream) + self.assertEqual(primitives.LongInteger.MAX, long_int.value) - def test_read_value_zero(self): - encoding = (b'\x00\x00\x00\x00\x00\x00\x00\x00') - self.stream = utils.BytearrayStream(encoding) - i = primitives.LongInteger() - i.read_value(self.stream) - - self.assertEqual(0, i.value, self.bad_read.format(0, i.value)) - - def test_read_value_max_positive(self): - encoding = (b'\x7f\xff\xff\xff\xff\xff\xff\xff') - self.stream = utils.BytearrayStream(encoding) - i = primitives.LongInteger() - i.read_value(self.stream) - - self.assertEqual(self.max_long, i.value, - self.bad_read.format(1, i.value)) - - def test_read_value_min_negative(self): - encoding = (b'\xff\xff\xff\xff\xff\xff\xff\xff') - self.stream = utils.BytearrayStream(encoding) - i = primitives.LongInteger() - i.read_value(self.stream) - - self.assertEqual(-1, i.value, - self.bad_read.format(1, i.value)) - - def test_read(self): + def test_read_min_max(self): + """ + Test that a LongInteger representing the minimum positive value can be + read from a byte stream. + """ encoding = ( b'\x42\x00\x00\x03\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00' b'\x01') - self.stream = utils.BytearrayStream(encoding) - i = primitives.LongInteger() - i.read(self.stream) + stream = utils.BytearrayStream(encoding) + long_int = primitives.LongInteger() + long_int.read(stream) + self.assertEqual(1, long_int.value) - self.assertEqual(1, i.value, self.bad_read.format(1, i.value)) + def test_read_max_min(self): + """ + Test that a LongInteger representing the maximum negative value can be + read from a byte stream. + """ + encoding = ( + b'\x42\x00\x00\x03\x00\x00\x00\x08\xff\xff\xff\xff\xff\xff\xff' + b'\xff') + stream = utils.BytearrayStream(encoding) + long_int = primitives.LongInteger() + long_int.read(stream) + self.assertEqual(-1, long_int.value) + + def test_read_min_min(self): + """ + Test that a LongInteger representing the minimum negative value can be + read from a byte stream. + """ + encoding = ( + b'\x42\x00\x00\x03\x00\x00\x00\x08\x80\x00\x00\x00\x00\x00\x00' + b'\x00') + stream = utils.BytearrayStream(encoding) + long_int = primitives.LongInteger(primitives.LongInteger.MIN) + long_int.read(stream) + self.assertEqual(primitives.LongInteger.MIN, long_int.value) def test_read_on_invalid_length(self): + """ + Test that an InvalidPrimitiveLength exception is thrown when attempting + to decode a LongInteger with an invalid length. + """ encoding = ( b'\x42\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' b'\x00') - self.stream = utils.BytearrayStream(encoding) - i = primitives.LongInteger() + stream = utils.BytearrayStream(encoding) + long_int = primitives.LongInteger() - self.assertRaises(errors.ReadValueError, i.read, self.stream) + self.assertRaises( + exceptions.InvalidPrimitiveLength, long_int.read, stream) - def test_write_value(self): - encoding = (b'\x00\x00\x00\x00\x00\x00\x00\x01') - i = primitives.LongInteger(1) - i.write_value(self.stream) + def test_write_zero(self): + """ + Test that a LongInteger representing the value 0 can be written to a + byte stream. + """ + encoding = ( + b'\x42\x00\x00\x03\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00' + b'\x00') + stream = utils.BytearrayStream() + long_int = primitives.LongInteger(0) + long_int.write(stream) - result = self.stream.read() - len_exp = len(encoding) - len_rcv = len(result) + result = stream.read() + self.assertEqual(len(encoding), len(result)) + self.assertEqual(encoding, result) - self.assertEqual(len_exp, len_rcv, self.bad_write.format(len_exp, - len_rcv)) - self.assertEqual(encoding, result, self.bad_encoding) + def test_write_max_max(self): + """ + Test that a LongInteger representing the maximum positive value can be + written to a byte stream. + """ + encoding = ( + b'\x42\x00\x00\x03\x00\x00\x00\x08\x7f\xff\xff\xff\xff\xff\xff' + b'\xff') + stream = utils.BytearrayStream() + long_int = primitives.LongInteger(primitives.LongInteger.MAX) + long_int.write(stream) - def test_write_value_zero(self): - encoding = (b'\x00\x00\x00\x00\x00\x00\x00\x00') - i = primitives.LongInteger(0) - i.write_value(self.stream) + result = stream.read() + self.assertEqual(len(encoding), len(result)) + self.assertEqual(encoding, result) - result = self.stream.read() - len_exp = len(encoding) - len_rcv = len(result) - - self.assertEqual(len_exp, len_rcv, self.bad_write.format(len_exp, - len_rcv)) - self.assertEqual(encoding, result, self.bad_encoding) - - def test_write_value_max_positive(self): - encoding = (b'\x7f\xff\xff\xff\xff\xff\xff\xff') - i = primitives.LongInteger(self.max_long) - i.write_value(self.stream) - - result = self.stream.read() - len_exp = len(encoding) - len_rcv = len(result) - - self.assertEqual(len_exp, len_rcv, self.bad_write.format(len_exp, - len_rcv)) - self.assertEqual(encoding, result, self.bad_encoding) - - def test_write_value_min_negative(self): - encoding = (b'\xff\xff\xff\xff\xff\xff\xff\xff') - i = primitives.LongInteger(-1) - i.write_value(self.stream) - - result = self.stream.read() - len_exp = len(encoding) - len_rcv = len(result) - - self.assertEqual(len_exp, len_rcv, self.bad_write.format(len_exp, - len_rcv)) - self.assertEqual(encoding, result, self.bad_encoding) - - def test_write(self): + def test_write_min_max(self): + """ + Test that a LongInteger representing the minimum positive value can be + written to a byte stream. + """ encoding = ( b'\x42\x00\x00\x03\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00' b'\x01') - i = primitives.LongInteger(1) - i.write(self.stream) + stream = utils.BytearrayStream() + long_int = primitives.LongInteger(1) + long_int.write(stream) - result = self.stream.read() - len_exp = len(encoding) - len_rcv = len(result) + result = stream.read() + self.assertEqual(len(encoding), len(result)) + self.assertEqual(encoding, result) - self.assertEqual(len_exp, len_rcv, self.bad_write.format(len_exp, - len_rcv)) - self.assertEqual(encoding, result, self.bad_encoding) + def test_write_max_min(self): + """ + Test that a LongInteger representing the maximum negative value can be + written to a byte stream. + """ + encoding = ( + b'\x42\x00\x00\x03\x00\x00\x00\x08\xff\xff\xff\xff\xff\xff\xff' + b'\xff') + stream = utils.BytearrayStream() + long_int = primitives.LongInteger(-1) + long_int.write(stream) + + result = stream.read() + self.assertEqual(len(encoding), len(result)) + self.assertEqual(encoding, result) + + def test_write_min_min(self): + """ + Test that a LongInteger representing the minimum negative value can be + written to a byte stream. + """ + encoding = ( + b'\x42\x00\x00\x03\x00\x00\x00\x08\x80\x00\x00\x00\x00\x00\x00' + b'\x00') + stream = utils.BytearrayStream() + long_int = primitives.LongInteger(primitives.LongInteger.MIN) + long_int.write(stream) + + result = stream.read() + self.assertEqual(len(encoding), len(result)) + self.assertEqual(encoding, result) + + def test_repr(self): + """ + Test that the representation of a LongInteger is formatted properly. + """ + long_int = primitives.LongInteger() + value = "value={0}".format(long_int.value) + tag = "tag={0}".format(long_int.tag) + self.assertEqual( + "LongInteger({0}, {1})".format(value, tag), repr(long_int)) + + def test_str(self): + """ + Test that the string representation of a LongInteger is formatted + properly. + """ + self.assertEqual("0", str(primitives.LongInteger())) + + def test_equal_on_equal(self): + """ + Test that the equality operator returns True when comparing two + LongIntegers. + """ + a = primitives.LongInteger(1) + b = primitives.LongInteger(1) + + self.assertTrue(a == b) + self.assertTrue(b == a) + + def test_equal_on_equal_and_empty(self): + """ + Test that the equality operator returns True when comparing two + LongIntegers. + """ + a = primitives.LongInteger() + b = primitives.LongInteger() + + self.assertTrue(a == b) + self.assertTrue(b == a) + + def test_equal_on_not_equal(self): + """ + Test that the equality operator returns False when comparing two + LongIntegers with different values. + """ + a = primitives.LongInteger(1) + b = primitives.LongInteger(2) + + self.assertFalse(a == b) + self.assertFalse(b == a) + + def test_equal_on_type_mismatch(self): + """ + Test that the equality operator returns False when comparing a + LongInteger to a non-LongInteger object. + """ + a = primitives.LongInteger() + b = 'invalid' + + self.assertFalse(a == b) + self.assertFalse(b == a) + + def test_not_equal_on_equal(self): + """ + Test that the inequality operator returns False when comparing + two LongIntegers with the same values. + """ + a = primitives.LongInteger(1) + b = primitives.LongInteger(1) + + self.assertFalse(a != b) + self.assertFalse(b != a) + + def test_not_equal_on_equal_and_empty(self): + """ + Test that the inequality operator returns False when comparing + two LongIntegers. + """ + a = primitives.LongInteger() + b = primitives.LongInteger() + + self.assertFalse(a != b) + self.assertFalse(b != a) + + def test_not_equal_on_not_equal(self): + """ + Test that the inequality operator returns True when comparing two + LongIntegers with different values. + """ + a = primitives.LongInteger(1) + b = primitives.LongInteger(2) + + self.assertTrue(a != b) + self.assertTrue(b != a) + + def test_not_equal_on_type_mismatch(self): + """ + Test that the inequality operator returns True when comparing a + LongInteger to a non-LongInteger object. + """ + a = primitives.LongInteger() + b = 'invalid' + + self.assertTrue(a != b) + self.assertTrue(b != a)