diff --git a/kmip/core/exceptions.py b/kmip/core/exceptions.py index fcab164..8308d93 100644 --- a/kmip/core/exceptions.py +++ b/kmip/core/exceptions.py @@ -21,6 +21,14 @@ class InvalidKmipEncoding(Exception): pass +class InvalidPaddingBytes(Exception): + """ + An exception raised for errors when processing the padding bytes of + primitive encodings. + """ + pass + + class InvalidPrimitiveLength(Exception): """ An exception raised for errors when processing primitives with invalid diff --git a/kmip/core/primitives.py b/kmip/core/primitives.py index e5f8050..b645942 100644 --- a/kmip/core/primitives.py +++ b/kmip/core/primitives.py @@ -839,8 +839,109 @@ class DateTime(LongInteger): return time.ctime(self.value) -class Interval(Integer): +class Interval(Base): + """ + An encodeable object representing an interval of time. - def __init__(self, value=None, tag=Tags.DEFAULT): - super(Interval, self).__init__(value, tag) - self.type = Types.INTERVAL + An Interval is one of the KMIP primitive object types. It is encoded as + an unsigned, big-endian, 32-bit integer, where the value has a resolution + of one second. For more information, see Section 9.1 of the KMIP 1.1 + specification. + """ + LENGTH = 4 + + # Bounds for unsigned 32-bit integers + MIN = 0 + MAX = 4294967296 + + def __init__(self, value=0, tag=Tags.DEFAULT): + super(Interval, self).__init__(tag, type=Types.INTERVAL) + + self.value = value + self.length = Interval.LENGTH + + self.validate() + + def read(self, istream): + """ + Read the encoding of the Interval from the input stream. + + Args: + istream (stream): A buffer containing the encoded bytes of the + value of an Interval. Usually a BytearrayStream object. + Required. + + Raises: + InvalidPrimitiveLength: if the Interval encoding read in has an + invalid encoded length. + InvalidPaddingBytes: if the Interval encoding read in does not use + zeroes for its padding bytes. + """ + super(Interval, self).read(istream) + + # Check for a valid length before even trying to parse the value. + if self.length != Interval.LENGTH: + raise exceptions.InvalidPrimitiveLength( + "interval length must be {0}".format(Interval.LENGTH)) + + # Decode the Interval value and the padding bytes. + self.value = unpack('!I', istream.read(Interval.LENGTH))[0] + pad = unpack('!I', istream.read(Interval.LENGTH))[0] + + # Verify that the padding bytes are zero bytes. + if pad is not 0: + raise exceptions.InvalidPaddingBytes("padding bytes must be zero") + + self.validate() + + def write(self, ostream): + """ + Write the encoding of the Interval to the output stream. + + Args: + ostream (stream): A buffer to contain the encoded bytes of an + Interval. Usually a BytearrayStream object. Required. + """ + super(Interval, self).write(ostream) + ostream.write(pack('!I', self.value)) + ostream.write(pack('!I', 0)) + + def validate(self): + """ + Verify that the value of the Interval is valid. + + Raises: + TypeError: if the value is not of type int or long + ValueError: if the value cannot be represented by an unsigned + 32-bit integer + """ + if self.value is not None: + if type(self.value) not in six.integer_types: + raise TypeError('expected (one of): {0}, observed: {1}'.format( + six.integer_types, type(self.value))) + else: + if self.value > Interval.MAX: + raise ValueError( + 'interval value greater than accepted max') + elif self.value < Interval.MIN: + raise ValueError('interval value less than accepted min') + + def __repr__(self): + value = "value={0}".format(self.value) + tag = "tag={0}".format(self.tag) + return "Interval({0}, {1})".format(value, tag) + + def __str__(self): + return "{0}".format(self.value) + + def __eq__(self, other): + if isinstance(other, Interval): + return self.value == other.value + else: + return NotImplemented + + def __ne__(self, other): + if isinstance(other, Interval): + return not self.__eq__(other) + else: + return NotImplemented diff --git a/kmip/tests/unit/core/primitives/test_interval.py b/kmip/tests/unit/core/primitives/test_interval.py index 62ee3d8..5f8cbde 100644 --- a/kmip/tests/unit/core/primitives/test_interval.py +++ b/kmip/tests/unit/core/primitives/test_interval.py @@ -15,35 +15,211 @@ import testtools +from kmip.core import exceptions +from kmip.core import primitives from kmip.core import utils class TestInterval(testtools.TestCase): + """ + Test suite for the Interval primitive. + """ def setUp(self): super(TestInterval, self).setUp() - self.stream = utils.BytearrayStream() + + # Encoding and value based on Section 9.1.2 of the KMIP 1.1 + # specification. + self.value = 864000 + self.encoding = ( + b'\x42\x00\x00\x0A\x00\x00\x00\x04\x00\x0D\x2F\x00\x00\x00\x00' + b'\x00') + self.encoding_bad_length = ( + b'\x42\x00\x00\x0A\x00\x00\x00\x05\x00\x0D\x2F\x00\x00\x00\x00' + b'\x00') + self.encoding_bad_padding = ( + b'\x42\x00\x00\x0A\x00\x00\x00\x04\x00\x0D\x2F\x00\x00\x00\x00' + b'\xFF') def tearDown(self): super(TestInterval, self).tearDown() def test_init(self): - self.skip('') + """ + Test that an Interval can be instantiated. + """ + interval = primitives.Interval(1) + self.assertEqual(1, interval.value) def test_init_unset(self): - self.skip('') - - def test_validate_on_valid(self): - self.skip('') - - def test_validate_on_valid_unset(self): - self.skip('') + """ + Test that an Interval can be instantiated with no input. + """ + interval = primitives.Interval() + self.assertEqual(0, interval.value) def test_validate_on_invalid_type(self): - self.skip('') + """ + Test that a TypeError is thrown on input of invalid type (e.g., str). + """ + self.assertRaises(TypeError, primitives.Interval, 'invalid') + + def test_validate_on_invalid_value_too_big(self): + """ + Test that a ValueError is thrown on input that is too large. + """ + self.assertRaises( + ValueError, primitives.Interval, primitives.Interval.MAX + 1) + + def test_validate_on_invalid_value_too_small(self): + """ + Test that a ValueError is thrown on input that is too small. + """ + self.assertRaises( + ValueError, primitives.Interval, primitives.Interval.MIN - 1) def test_read(self): - self.skip('') + """ + Test that an Interval can be read from a byte stream. + """ + stream = utils.BytearrayStream(self.encoding) + interval = primitives.Interval() + interval.read(stream) + self.assertEqual(self.value, interval.value) + + def test_read_on_invalid_length(self): + """ + Test that an InvalidPrimitiveLength exception is thrown when attempting + to decode an Interval with an invalid length. + """ + stream = utils.BytearrayStream(self.encoding_bad_length) + interval = primitives.Interval() + self.assertRaises( + exceptions.InvalidPrimitiveLength, interval.read, stream) + + def test_read_on_invalid_padding(self): + """ + Test that an InvalidPaddingBytes exception is thrown when attempting + to decode an Interval with invalid padding bytes. + """ + stream = utils.BytearrayStream(self.encoding_bad_padding) + interval = primitives.Interval() + self.assertRaises( + exceptions.InvalidPaddingBytes, interval.read, stream) def test_write(self): - self.skip('') + """ + Test that an Interval can be written to a byte stream. + """ + stream = utils.BytearrayStream() + interval = primitives.Interval(self.value) + interval.write(stream) + + result = stream.read() + self.assertEqual(len(self.encoding), len(result)) + self.assertEqual(self.encoding, result) + + def test_repr(self): + """ + Test that the representation of a Interval is formatted properly. + """ + long_int = primitives.Interval() + value = "value={0}".format(long_int.value) + tag = "tag={0}".format(long_int.tag) + self.assertEqual( + "Interval({0}, {1})".format(value, tag), repr(long_int)) + + def test_str(self): + """ + Test that the string representation of a Interval is formatted + properly. + """ + self.assertEqual("0", str(primitives.Interval())) + + def test_equal_on_equal(self): + """ + Test that the equality operator returns True when comparing two + Intervals. + """ + a = primitives.Interval(1) + b = primitives.Interval(1) + + self.assertTrue(a == b) + self.assertTrue(b == a) + + def test_equal_on_equal_and_empty(self): + """ + Test that the equality operator returns True when comparing two + Intervals. + """ + a = primitives.Interval() + b = primitives.Interval() + + self.assertTrue(a == b) + self.assertTrue(b == a) + + def test_equal_on_not_equal(self): + """ + Test that the equality operator returns False when comparing two + Intervals with different values. + """ + a = primitives.Interval(1) + b = primitives.Interval(2) + + self.assertFalse(a == b) + self.assertFalse(b == a) + + def test_equal_on_type_mismatch(self): + """ + Test that the equality operator returns False when comparing a + Interval to a non-Interval object. + """ + a = primitives.Interval() + b = 'invalid' + + self.assertFalse(a == b) + self.assertFalse(b == a) + + def test_not_equal_on_equal(self): + """ + Test that the inequality operator returns False when comparing + two Intervals with the same values. + """ + a = primitives.Interval(1) + b = primitives.Interval(1) + + self.assertFalse(a != b) + self.assertFalse(b != a) + + def test_not_equal_on_equal_and_empty(self): + """ + Test that the inequality operator returns False when comparing + two Intervals. + """ + a = primitives.Interval() + b = primitives.Interval() + + self.assertFalse(a != b) + self.assertFalse(b != a) + + def test_not_equal_on_not_equal(self): + """ + Test that the inequality operator returns True when comparing two + Intervals with different values. + """ + a = primitives.Interval(1) + b = primitives.Interval(2) + + self.assertTrue(a != b) + self.assertTrue(b != a) + + def test_not_equal_on_type_mismatch(self): + """ + Test that the inequality operator returns True when comparing a + Interval to a non-Interval object. + """ + a = primitives.Interval() + b = 'invalid' + + self.assertTrue(a != b) + self.assertTrue(b != a)