diff --git a/components/discovery/discoverycomponent.cpp b/components/discovery/discoverycomponent.cpp index ce88f2028..7b65e60c1 100644 --- a/components/discovery/discoverycomponent.cpp +++ b/components/discovery/discoverycomponent.cpp @@ -24,6 +24,8 @@ void DiscoveryComponent::Start(void) m_DiscoveryEndpoint->RegisterMethodSource("discovery::RegisterComponent"); m_DiscoveryEndpoint->RegisterMethodHandler("discovery::NewComponent", bind_weak(&DiscoveryComponent::NewComponentMessageHandler, shared_from_this())); + m_DiscoveryEndpoint->RegisterMethodHandler("discovery::Welcome", + bind_weak(&DiscoveryComponent::WelcomeMessageHandler, shared_from_this())); GetEndpointManager()->ForeachEndpoint(bind(&DiscoveryComponent::NewEndpointHandler, this, _1)); GetEndpointManager()->OnNewEndpoint += bind_weak(&DiscoveryComponent::NewEndpointHandler, shared_from_this()); @@ -71,6 +73,8 @@ int DiscoveryComponent::NewEndpointHandler(const NewEndpointEventArgs& neea) neea.Endpoint->RegisterMethodSource("discovery::RegisterComponent"); } + neea.Endpoint->RegisterMethodSource("discovery::Welcome"); + /* TODO: implement message broker authorisation */ neea.Endpoint->RegisterMethodSource("discovery::NewComponent"); @@ -181,11 +185,61 @@ int DiscoveryComponent::NewIdentityHandler(const EventArgs& ea) return 0; } - // TODO: send discovery::Welcome message - // TODO: add subscriptions/provides to this endpoint + FinishDiscoverySetup(endpoint); + return 0; } +int DiscoveryComponent::WelcomeMessageHandler(const NewRequestEventArgs& nrea) +{ + Endpoint::Ptr endpoint = nrea.Sender; + + if (endpoint->GetHandshakeCounter() >= 2) + return 0; + + endpoint->IncrementHandshakeCounter(); + + if (endpoint->GetHandshakeCounter() >= 2) { + EventArgs ea; + ea.Source = shared_from_this(); + endpoint->OnSessionEstablished(ea); + } + + return 0; +} + +void DiscoveryComponent::FinishDiscoverySetup(Endpoint::Ptr endpoint) +{ + if (endpoint->GetHandshakeCounter() >= 2) + return; + + // we assume the other component _always_ wants + // discovery::Welcome messages from us + endpoint->RegisterMethodSink("discovery::Welcome"); + JsonRpcRequest request; + request.SetMethod("discovery::Welcome"); + GetEndpointManager()->SendUnicastRequest(m_DiscoveryEndpoint, endpoint, request); + + ComponentDiscoveryInfo::Ptr info; + + if (GetComponentDiscoveryInfo(endpoint->GetIdentity(), &info)) { + set::iterator i; + for (i = info->PublishedMethods.begin(); i != info->PublishedMethods.end(); i++) + endpoint->RegisterMethodSource(*i); + + for (i = info->SubscribedMethods.begin(); i != info->SubscribedMethods.end(); i++) + endpoint->RegisterMethodSink(*i); + } + + endpoint->IncrementHandshakeCounter(); + + if (endpoint->GetHandshakeCounter() >= 2) { + EventArgs ea; + ea.Source = shared_from_this(); + endpoint->OnSessionEstablished(ea); + } +} + void DiscoveryComponent::SendDiscoveryMessage(string method, string identity, Endpoint::Ptr recipient) { JsonRpcRequest request; @@ -257,7 +311,12 @@ void DiscoveryComponent::ProcessDiscoveryMessage(string identity, DiscoveryMessa m_Components[identity] = info; - SendDiscoveryMessage("discovery::NewComponent", identity, Endpoint::Ptr()); + if (IsBroker()) + SendDiscoveryMessage("discovery::NewComponent", identity, Endpoint::Ptr()); + + Endpoint::Ptr endpoint = GetEndpointManager()->GetEndpointByIdentity(identity); + if (endpoint) + FinishDiscoverySetup(endpoint); } int DiscoveryComponent::NewComponentMessageHandler(const NewRequestEventArgs& nrea) @@ -287,7 +346,8 @@ int DiscoveryComponent::ReconnectTimerHandler(const TimerEventArgs& tea) map::iterator i; for (i = m_Components.begin(); i != m_Components.end(); i++) { - if (endpointManager->HasConnectedEndpoint(i->first)) + Endpoint::Ptr endpoint = endpointManager->GetEndpointByIdentity(i->first); + if (endpoint) continue; ComponentDiscoveryInfo::Ptr info = i->second; diff --git a/components/discovery/discoverycomponent.h b/components/discovery/discoverycomponent.h index 2a20c863b..269acf5cb 100644 --- a/components/discovery/discoverycomponent.h +++ b/components/discovery/discoverycomponent.h @@ -31,6 +31,8 @@ private: int NewComponentMessageHandler(const NewRequestEventArgs& nrea); int RegisterComponentMessageHandler(const NewRequestEventArgs& nrea); + int WelcomeMessageHandler(const NewRequestEventArgs& nrea); + void SendDiscoveryMessage(string method, string identity, Endpoint::Ptr recipient); void ProcessDiscoveryMessage(string identity, DiscoveryMessage message); @@ -45,6 +47,8 @@ private: bool IsBroker(void) const; + void FinishDiscoverySetup(Endpoint::Ptr endpoint); + public: virtual string GetName(void) const; virtual void Start(void); diff --git a/icinga/endpoint.cpp b/icinga/endpoint.cpp index 3fc126e3d..8847a76c5 100644 --- a/icinga/endpoint.cpp +++ b/icinga/endpoint.cpp @@ -2,6 +2,11 @@ using namespace icinga; +Endpoint::Endpoint(void) +{ + m_HandshakeCounter = false; +} + string Endpoint::GetIdentity(void) const { return m_Identity; @@ -120,3 +125,13 @@ set::const_iterator Endpoint::EndSources(void) const { return m_MethodSources.end(); } + +void Endpoint::IncrementHandshakeCounter(void) +{ + m_HandshakeCounter++; +} + +unsigned short Endpoint::GetHandshakeCounter(void) const +{ + return m_HandshakeCounter; +} diff --git a/icinga/endpoint.h b/icinga/endpoint.h index 3e14051cf..92235a972 100644 --- a/icinga/endpoint.h +++ b/icinga/endpoint.h @@ -17,6 +17,7 @@ private: string m_Identity; set m_MethodSinks; set m_MethodSources; + unsigned short m_HandshakeCounter; weak_ptr m_EndpointManager; @@ -24,12 +25,17 @@ public: typedef shared_ptr Ptr; typedef weak_ptr WeakPtr; + Endpoint(void); + virtual string GetAddress(void) const = 0; string GetIdentity(void) const; void SetIdentity(string identity); bool HasIdentity(void) const; + void IncrementHandshakeCounter(); + unsigned short GetHandshakeCounter(void) const; + shared_ptr GetEndpointManager(void) const; void SetEndpointManager(weak_ptr manager); diff --git a/icinga/endpointmanager.cpp b/icinga/endpointmanager.cpp index 19e8d9dc1..254176b08 100644 --- a/icinga/endpointmanager.cpp +++ b/icinga/endpointmanager.cpp @@ -147,13 +147,13 @@ void EndpointManager::ForeachEndpoint(function::const_iterator i; for (i = m_Endpoints.begin(); i != m_Endpoints.end(); i++) { if ((*i)->GetIdentity() == identity) - return true; + return *i; } - return false; + return Endpoint::Ptr(); } diff --git a/icinga/endpointmanager.h b/icinga/endpointmanager.h index 2061e36c2..b543826a9 100644 --- a/icinga/endpointmanager.h +++ b/icinga/endpointmanager.h @@ -44,7 +44,7 @@ public: void ForeachEndpoint(function callback); - bool HasConnectedEndpoint(string identity) const; + Endpoint::Ptr GetEndpointByIdentity(string identity) const; Event OnNewEndpoint; };