diff --git a/kmip/pie/objects.py b/kmip/pie/objects.py index 8bbe780..6ce62b0 100644 --- a/kmip/pie/objects.py +++ b/kmip/pie/objects.py @@ -48,6 +48,28 @@ app_specific_info_map = sqlalchemy.Table( ) +object_group_map = sqlalchemy.Table( + "object_group_map", + sql.Base.metadata, + sqlalchemy.Column( + "managed_object_id", + sqlalchemy.Integer, + sqlalchemy.ForeignKey( + "managed_objects.uid", + ondelete="CASCADE" + ) + ), + sqlalchemy.Column( + "object_group_id", + sqlalchemy.Integer, + sqlalchemy.ForeignKey( + "object_groups.id", + ondelete="CASCADE" + ) + ) +) + + class ManagedObject(sql.Base): """ The abstract base class of the simplified KMIP object hierarchy. @@ -92,6 +114,12 @@ class ManagedObject(sql.Base): back_populates="managed_objects", passive_deletes=True ) + object_groups = sqlalchemy.orm.relationship( + "ObjectGroup", + secondary=object_group_map, + back_populates="managed_objects", + passive_deletes=True + ) __mapper_args__ = { 'polymorphic_identity': 'ManagedObject', @@ -1841,3 +1869,64 @@ class ApplicationSpecificInformation(sql.Base): return not (self == other) else: return NotImplemented + + +class ObjectGroup(sql.Base): + __tablename__ = "object_groups" + id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True) + _object_group = sqlalchemy.Column( + "object_group", + sqlalchemy.String, + nullable=False, + unique=True + ) + managed_objects = sqlalchemy.orm.relationship( + "ManagedObject", + secondary=object_group_map, + back_populates="object_groups" + ) + + def __init__(self, object_group=None): + """ + Create an ObjectGroup attribute. + + Args: + object_group (str): A string specifying the object group. Required. + """ + super(ObjectGroup, self).__init__() + + self.object_group = object_group + + @property + def object_group(self): + return self._object_group + + @object_group.setter + def object_group(self, value): + if (value is None) or (isinstance(value, six.string_types)): + self._object_group = value + else: + raise TypeError("The object group must be a string.") + + def __repr__(self): + object_group = "object_group={}".format(self.object_group) + + return "ObjectGroup({})".format(object_group) + + def __str__(self): + return str({"object_group": self.object_group}) + + def __eq__(self, other): + if isinstance(other, ObjectGroup): + if self.object_group != other.object_group: + return False + else: + return True + else: + return NotImplemented + + def __ne__(self, other): + if isinstance(other, ObjectGroup): + return not (self == other) + else: + return NotImplemented diff --git a/kmip/tests/unit/pie/objects/test_object_group.py b/kmip/tests/unit/pie/objects/test_object_group.py new file mode 100644 index 0000000..a5efab2 --- /dev/null +++ b/kmip/tests/unit/pie/objects/test_object_group.py @@ -0,0 +1,185 @@ +# Copyright (c) 2019 The Johns Hopkins University/Applied Physics Laboratory +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import sqlalchemy +import testtools + +from kmip.pie import objects +from kmip.pie import sqltypes + + +class TestObjectGroup(testtools.TestCase): + """ + Test suite for ObjectGroup. + """ + + def setUp(self): + super(TestObjectGroup, self).setUp() + + def tearDown(self): + super(TestObjectGroup, self).tearDown() + + def test_init(self): + """ + Test that an ObjectGroup object can be instantiated. + """ + object_group = objects.ObjectGroup() + + self.assertIsNone(object_group.object_group) + + def test_invalid_object_group(self): + """ + Test that a TypeError is raised when an invalid object group value + is used to construct an ObjectGroup attribute. + """ + kwargs = {"object_group": []} + self.assertRaisesRegex( + TypeError, + "The object group must be a string.", + objects.ObjectGroup, + **kwargs + ) + + args = ( + objects.ObjectGroup(), + "object_group", + [] + ) + self.assertRaisesRegex( + TypeError, + "The object group must be a string.", + setattr, + *args + ) + + def test_repr(self): + """ + Test that repr can be applied to an ObjectGroup attribute. + """ + object_group = objects.ObjectGroup(object_group="Group1") + + expected = "ObjectGroup({})".format("object_group={}".format("Group1")) + observed = repr(object_group) + + self.assertEqual(expected, observed) + + def test_str(self): + """ + Test that str can be applied to an ObjectGroup attribute. + """ + object_group = objects.ObjectGroup(object_group="Group1") + + expected = str( + { + "object_group": "Group1" + } + ) + observed = str(object_group) + + self.assertEqual(expected, observed) + + def test_comparison_on_equal(self): + """ + Test that the equality/inequality operators return True/False when + comparing two ObjectGroup attributes with the same + data. + """ + a = objects.ObjectGroup() + b = objects.ObjectGroup() + + self.assertTrue(a == b) + self.assertTrue(b == a) + self.assertFalse(a != b) + self.assertFalse(b != a) + + a = objects.ObjectGroup(object_group="Group1") + b = objects.ObjectGroup(object_group="Group1") + + self.assertTrue(a == b) + self.assertTrue(b == a) + self.assertFalse(a != b) + self.assertFalse(b != a) + + def test_comparison_on_different_object_groups(self): + """ + Test that the equality/inequality operators return False/True when + comparing two ObjectGroup attributes with different object groups. + """ + a = objects.ObjectGroup(object_group="a") + b = objects.ObjectGroup(object_group="b") + + self.assertFalse(a == b) + self.assertFalse(b == a) + self.assertTrue(a != b) + self.assertTrue(b != a) + + def test_comparison_on_type_mismatch(self): + """ + Test that the equality/inequality operators return False/True when + comparing an ObjectGroup attribute to a non ObjectGroup attribute. + """ + a = objects.ObjectGroup() + b = "invalid" + + self.assertFalse(a == b) + self.assertFalse(b == a) + self.assertTrue(a != b) + self.assertTrue(b != a) + + def test_save(self): + """ + Test that an ObjectGroup attribute can be saved using SQLAlchemy. This + test will add an attribute instance to the database, verify that no + exceptions are thrown, and check that its ID was set. + """ + object_group = objects.ObjectGroup(object_group="Group1") + + engine = sqlalchemy.create_engine("sqlite:///:memory:", echo=True) + sqltypes.Base.metadata.create_all(engine) + + session = sqlalchemy.orm.sessionmaker(bind=engine)() + session.add(object_group) + session.commit() + + self.assertIsNotNone(object_group.id) + + def test_get(self): + """ + Test that an ObjectGroup attribute can be saved and then retrieved + using SQLAlchemy. This test adds the attribute to the database and then + retrieves it by ID and verifies its values. + """ + object_group = objects.ObjectGroup(object_group="Group1") + + engine = sqlalchemy.create_engine("sqlite:///:memory:", echo=True) + sqltypes.Base.metadata.create_all(engine) + + session = sqlalchemy.orm.sessionmaker(bind=engine)() + session.add(object_group) + session.commit() + + # Grab the ID now before making a new session to avoid a Detached error + # See http://sqlalche.me/e/bhk3 for more info. + object_group_id = object_group.id + + session = sqlalchemy.orm.sessionmaker(bind=engine)() + retrieved_group = session.query( + objects.ObjectGroup + ).filter( + objects.ObjectGroup.id == object_group_id + ).one() + session.commit() + + self.assertEqual("Group1", retrieved_group.object_group)