# Source code for mcvirt.client.rpc

"""Provide class for connecting to RPC daemon"""

# Copyright (c) 2016 - I.T. Dev Ltd
#
# This file is part of MCVirt.
#
# MCVirt is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# MCVirt is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with MCVirt.  If not, see <http://www.gnu.org/licenses/>

import Pyro4

import mcvirt.exceptions  # Import necessary for Pyro # noqa
from mcvirt.exceptions import AuthenticationError
from mcvirt.utils import get_hostname
from mcvirt.rpc.ssl_socket import SSLSocket
from mcvirt.rpc.constants import Annotations


class Connection(object):
    """Connection class, providing connections to the Pyro MCVirt daemon.

    On construction the Pyro4 socket factories are pointed at the SSL
    socket implementations and an initial connection is made to the
    session object in order to obtain (or verify) a session ID. That
    session ID is then attached to the handshake data of every proxy
    created through :meth:`get_connection` or annotated through
    :meth:`annotate_object`.
    """

    # Port on which the Pyro name server is looked up
    NS_PORT = 9090
    # Name of the remote object used for the initial session handshake
    SESSION_OBJECT = 'session'

    def __init__(self, username=None, password=None, session_id=None,
                 host=None, ignore_cluster=False):
        """Store member variables for connecting and obtain a session ID.

        :param username: Username placed in the authentication handshake.
        :param password: Password used for the initial connection only; it
            is not stored on the instance.
        :param session_id: Existing session ID, used for the initial
            connection when no password is supplied.
        :param host: Host on which the name server is looked up; defaults
            to the value returned by ``get_hostname()``.
        :param ignore_cluster: Initial value of the ignore-cluster flag.
        :raises AuthenticationError: if Pyro4 reports a communication error
            during the initial connection.
        :raises mcvirt.exceptions.InaccessibleNodeException: if Pyro4
            reports a naming error during the initial connection.
        """
        # If the host is not provided, default to the local host
        self.__host = host if host is not None else get_hostname()

        # These are assigned on the shared Pyro4.config object, so they
        # affect every Pyro4 user in the process, not just this instance.
        Pyro4.config.USE_MSG_WAITALL = False
        Pyro4.config.CREATE_SOCKET_METHOD = SSLSocket.create_ssl_socket
        Pyro4.config.CREATE_BROADCAST_SOCKET_METHOD = SSLSocket.create_broadcast_ssl_socket

        self.__username = username
        self.__ignore_drbd = False
        self.__ignore_cluster = ignore_cluster

        # Pick up a proxy user from the current Pyro call context, if one
        # has been set on it
        self.__proxy_username = getattr(Pyro4.current_context, 'proxy_user', None)

        # Store the passed session_id so that it may be used for the initial connection
        self.__session_id = session_id

        # Perform an initial connection to obtain/verify the session ID
        self.__session_id = self.__get_session(password=password)

    def __get_session(self, password):
        """Obtain a session ID by connecting to the session object.

        :param password: Password passed through to the authentication
            handshake (may be ``None`` to fall back to the stored
            session ID).
        :returns: The value stored under ``Annotations.SESSION_ID`` in the
            proxy's handshake data after binding.
        :raises AuthenticationError: on ``Pyro4.errors.CommunicationError``.
        :raises mcvirt.exceptions.InaccessibleNodeException: on
            ``Pyro4.errors.NamingError``.
        """
        try:
            # Attempt to obtain a connection and obtain a session ID
            session_object = self.get_connection(object_name=self.SESSION_OBJECT,
                                                 password=password)
            return session_object._pyroHandshake[Annotations.SESSION_ID]
        except Pyro4.errors.CommunicationError as e:
            raise AuthenticationError(str(e))
        except Pyro4.errors.NamingError:
            raise mcvirt.exceptions.InaccessibleNodeException(
                'MCVirt nameserver/daemon is not running on node %s' % self.__host
            )

    def _get_auth_obj(self, password=None):
        """Setup annotations for authentication.

        :param password: If truthy, sent as the credential; otherwise the
            stored session ID (if any) is sent instead.
        :returns: dict keyed by ``Annotations`` constants, used as the
            Pyro handshake data.
        """
        auth_dict = {
            Annotations.USERNAME: self.__username
        }

        # A password takes precedence over an existing session ID
        if password:
            auth_dict[Annotations.PASSWORD] = password
        elif self.__session_id:
            auth_dict[Annotations.SESSION_ID] = self.__session_id

        if self.__proxy_username:
            auth_dict[Annotations.PROXY_USER] = self.__proxy_username

        context = Pyro4.current_context

        # Only forward the lock state when the current context carries one
        if hasattr(context, 'has_lock'):
            auth_dict[Annotations.HAS_LOCK] = context.has_lock

        # The ignore flags are the OR of this connection's own setting and
        # whatever is set on the current Pyro call context (if anything)
        auth_dict[Annotations.IGNORE_CLUSTER] = self.__ignore_cluster
        if hasattr(context, 'ignore_cluster'):
            auth_dict[Annotations.IGNORE_CLUSTER] |= context.ignore_cluster

        auth_dict[Annotations.IGNORE_Drbd] = self.__ignore_drbd
        if hasattr(context, 'ignore_drbd'):
            auth_dict[Annotations.IGNORE_Drbd] |= context.ignore_drbd

        return auth_dict

    def get_connection(self, object_name, password=None):
        """Obtain a connection from pyro for a given object.

        :param object_name: Name under which the object is registered in
            the name server.
        :param password: Optional password for the authentication
            handshake; see :meth:`_get_auth_obj`.
        :returns: A bound Pyro proxy for the object whose handshake data
            contains the authentication annotations.
        """
        # Obtain a connection to the name server on the configured host
        ns = Pyro4.naming.locateNS(host=self.__host, port=self.NS_PORT, broadcast=False)

        class AuthProxy(Pyro4.Proxy):
            """Proxy that keeps the handshake response as the session ID."""

            def _pyroValidateHandshake(self, data):
                """Store the handshake response data as the session ID."""
                # Override upstream # noqa
                self._pyroHandshake[Annotations.SESSION_ID] = data

        # Create a Proxy object, using the overridden Proxy class and return.
        proxy = AuthProxy(ns.lookup(object_name))
        proxy._pyroHandshake = self._get_auth_obj(password=password)
        # Bind explicitly so the handshake happens now rather than on
        # the first remote call
        proxy._pyroBind()
        return proxy

    def ignore_drbd(self):
        """Set flag to ignore DRBD in subsequently built handshake data."""
        self.__ignore_drbd = True

    def ignore_cluster(self):
        """Set flag to ignore cluster in subsequently built handshake data."""
        self.__ignore_cluster = True

    def annotate_object(self, object_ref):
        """Add authentication attributes to remote object.

        :param object_ref: Object whose ``_pyroHandshake`` attribute is
            replaced with freshly built authentication data (no password;
            the stored session ID is used when present).
        """
        object_ref._pyroHandshake = self._get_auth_obj()

    @property
    def session_id(self):
        """Property for the session ID."""
        return self.__session_id