"""Provide class for RPC daemon."""
# Copyright (c) 2016 - I.T. Dev Ltd
#
# This file is part of MCVirt.
#
# MCVirt is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# MCVirt is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with MCVirt. If not, see <http://www.gnu.org/licenses/>
import atexit
import Pyro4
import signal
import time
from mcvirt.auth.auth import Auth
from mcvirt.auth.permissions import PERMISSIONS
from mcvirt.virtual_machine.factory import Factory as VirtualMachineFactory
from mcvirt.iso.factory import Factory as IsoFactory
from mcvirt.node.network.factory import Factory as NetworkFactory
from mcvirt.virtual_machine.hard_drive.factory import Factory as HardDriveFactory
from mcvirt.auth.factory import Factory as UserFactory
from mcvirt.auth.session import Session
from mcvirt.cluster.cluster import Cluster
from mcvirt.virtual_machine.network_adapter.factory import Factory as NetworkAdapterFactory
from mcvirt.logger import Logger
from mcvirt.node.drbd import Drbd as NodeDrbd
from mcvirt.node.node import Node
from mcvirt.rpc.ssl_socket import SSLSocket
from mcvirt.rpc.certificate_generator_factory import CertificateGeneratorFactory
from mcvirt.node.libvirt_config import LibvirtConfig
from mcvirt.libvirt_connector import LibvirtConnector
from mcvirt.utils import get_hostname
from mcvirt.rpc.constants import Annotations
from mcvirt.syslogger import Syslogger
from mcvirt.rpc.daemon_lock import DaemonLock
from mcvirt.client.rpc import Connection
[docs]class BaseRpcDaemon(Pyro4.Daemon):
"""Override Pyro daemon to add authentication checks and MCVirt integration"""
def __init__(self, *args, **kwargs):
"""Override init to set required configuration and create nameserver connection"""
# Require all methods/classes to be exposed
# DO NOT CHANGE THIS OPTION!
Pyro4.config.REQUIRE_EXPOSE = True
# Perform super method for init of daemon
super(BaseRpcDaemon, self).__init__(*args, **kwargs)
# Store MCVirt instance
self.registered_factories = {}
[docs] def validateHandshake(self, conn, data): # Override name of upstream method # noqa
"""Perform authentication on new connections"""
# Reset session_id for current context
Pyro4.current_context.STARTUP_PERIOD = False
Pyro4.current_context.session_id = None
Pyro4.current_context.username = None
Pyro4.current_context.proxy_user = None
Pyro4.current_context.has_lock = False
Pyro4.current_context.cluster_master = True
# Check and store username from connection
if Annotations.USERNAME not in data:
raise Pyro4.errors.SecurityError('Username and password or Session must be passed')
username = str(data[Annotations.USERNAME])
# If a password has been provided
try:
# @TODO - Re-factor as the logic below is duplicated for SESSION_ID in data clause
if Annotations.PASSWORD in data:
# Store the password and perform authentication check
password = str(data[Annotations.PASSWORD])
session_instance = self.registered_factories['mcvirt_session']
session_id = session_instance.authenticate_user(username=username,
password=password)
if session_id:
Pyro4.current_context.username = username
Pyro4.current_context.session_id = session_id
# If the authenticated user can specify a proxy user, and a proxy user
# has been specified, set this in the current context
user_object = session_instance.get_current_user_object()
if user_object.allow_proxy_user and Annotations.PROXY_USER in data:
Pyro4.current_context.proxy_user = data[Annotations.PROXY_USER]
# If the user is a cluster/connection user, treat this connection
# as a cluster client (the command as been executed on a remote node)
# unless specified otherwise
auth = self.registered_factories['auth']
if user_object.CLUSTER_USER:
if user_object.CLUSTER_USER and Annotations.CLUSTER_MASTER in data:
Pyro4.current_context.cluster_master = data[Annotations.CLUSTER_MASTER]
else:
Pyro4.current_context.cluster_master = False
else:
Pyro4.current_context.cluster_master = True
if user_object.CLUSTER_USER and Annotations.HAS_LOCK in data:
Pyro4.current_context.has_lock = data[Annotations.HAS_LOCK]
else:
Pyro4.current_context.has_lock = False
if (auth.check_permission(PERMISSIONS.CAN_IGNORE_CLUSTER,
user_object=user_object) and
Annotations.IGNORE_CLUSTER in data):
Pyro4.current_context.ignore_cluster = data[Annotations.IGNORE_CLUSTER]
else:
Pyro4.current_context.ignore_cluster = False
if (auth.check_permission(PERMISSIONS.CAN_IGNORE_DRBD,
user_object=user_object) and
Annotations.IGNORE_Drbd in data):
Pyro4.current_context.ignore_drbd = data[Annotations.IGNORE_Drbd]
else:
Pyro4.current_context.ignore_drbd = False
if Pyro4.current_context.cluster_master:
self.registered_factories['cluster'].check_node_versions()
return session_id
# If a session id has been passed, store it and check the
# session_id/username against active sessions
elif Annotations.SESSION_ID in data:
session_id = str(data[Annotations.SESSION_ID])
session_instance = self.registered_factories['mcvirt_session']
if session_instance.authenticate_session(username=username, session=session_id):
Pyro4.current_context.username = username
Pyro4.current_context.session_id = session_id
# Determine if user can provide alternative users
user_object = session_instance.get_current_user_object()
if user_object.allow_proxy_user and Annotations.PROXY_USER in data:
Pyro4.current_context.proxy_user = data[Annotations.PROXY_USER]
# If the user is a cluster/connection user, treat this connection
# as a cluster client (the command as been executed on a remote node)
# unless specified otherwise
auth = self.registered_factories['auth']
if user_object.CLUSTER_USER:
if user_object.CLUSTER_USER and Annotations.CLUSTER_MASTER in data:
Pyro4.current_context.cluster_master = data[Annotations.CLUSTER_MASTER]
else:
Pyro4.current_context.cluster_master = False
else:
Pyro4.current_context.cluster_master = True
if user_object.CLUSTER_USER and Annotations.HAS_LOCK in data:
Pyro4.current_context.has_lock = data[Annotations.HAS_LOCK]
else:
Pyro4.current_context.has_lock = False
if (auth.check_permission(PERMISSIONS.CAN_IGNORE_CLUSTER,
user_object=user_object) and
Annotations.IGNORE_CLUSTER in data):
Pyro4.current_context.ignore_cluster = data[Annotations.IGNORE_CLUSTER]
else:
Pyro4.current_context.ignore_cluster = False
if (auth.check_permission(PERMISSIONS.CAN_IGNORE_DRBD,
user_object=user_object) and
Annotations.IGNORE_Drbd in data):
Pyro4.current_context.ignore_drbd = data[Annotations.IGNORE_Drbd]
else:
Pyro4.current_context.ignore_drbd = False
if Pyro4.current_context.cluster_master:
self.registered_factories['cluster'].check_node_versions()
return session_id
except Pyro4.errors.SecurityError:
raise
except Exception, e:
print str(e)
# If no valid authentication was provided, raise an error
raise Pyro4.errors.SecurityError('Invalid username/password/session')
[docs]class DaemonSession(object):
"""Class for allowing client to obtain the session ID"""
@Pyro4.expose()
[docs] def get_session_id(self):
"""Return the client's current session ID"""
if Pyro4.current_context.session_id:
return Pyro4.current_context.session_id
[docs]class RpcNSMixinDaemon(object):
"""Wrapper for the daemon. Required since the
Pyro daemon class overrides get/setattr and other
built-in object methods
"""
DAEMON = None
def __init__(self):
"""Store required object member variables and create MCVirt object"""
# Initialise Pyro4 with flag to showing that the daemon is being started
Pyro4.current_context.STARTUP_PERIOD = True
# Store nameserver, MCVirt instance and create daemon
self.daemon_lock = DaemonLock()
Pyro4.config.USE_MSG_WAITALL = False
Pyro4.config.CREATE_SOCKET_METHOD = SSLSocket.create_ssl_socket
Pyro4.config.CREATE_BROADCAST_SOCKET_METHOD = SSLSocket.create_broadcast_ssl_socket
Pyro4.config.THREADPOOL_ALLOW_QUEUE = True
Pyro4.config.THREADPOOL_SIZE = 128
self.hostname = get_hostname()
# Ensure that the required SSL certificates exist
ssl_socket = CertificateGeneratorFactory().get_cert_generator('localhost')
ssl_socket.check_certificates(check_client=False)
ssl_socket = None
# Wait for nameserver
self.obtain_connection()
RpcNSMixinDaemon.DAEMON = BaseRpcDaemon(host=self.hostname)
self.register_factories()
# Ensure libvirt is configured
cert_gen_factory = RpcNSMixinDaemon.DAEMON.registered_factories[
'certificate_generator_factory']
cert_gen = cert_gen_factory.get_cert_generator('localhost')
cert_gen.check_certificates()
cert_gen = None
cert_gen_factory = None
atexit.register(self.shutdown, 'atexit', '')
for sig in (signal.SIGABRT, signal.SIGILL, signal.SIGINT,
signal.SIGSEGV, signal.SIGTERM):
signal.signal(sig, self.shutdown)
[docs] def start(self, *args, **kwargs):
"""Start the Pyro daemon"""
Pyro4.current_context.STARTUP_PERIOD = False
with DaemonLock.LOCK:
RpcNSMixinDaemon.DAEMON.requestLoop(*args, **kwargs)
Syslogger.logger().debug('Daemon request loop finished')
[docs] def shutdown(self, signum, frame):
"""Shutdown Pyro Daemon"""
Syslogger.logger().error('Received signal: %s' % signum)
RpcNSMixinDaemon.DAEMON.shutdown()
Syslogger.logger().debug('finisehd shutdown')
[docs] def register(self, obj_or_class, objectId, *args, **kwargs): # Override upstream # noqa
"""Override register to register object with NS."""
Syslogger.logger().debug('Registering object: %s' % objectId)
uri = RpcNSMixinDaemon.DAEMON.register(obj_or_class, *args, **kwargs)
ns = Pyro4.naming.locateNS(host=self.hostname, port=9090, broadcast=False)
ns.register(objectId, uri)
ns = None
RpcNSMixinDaemon.DAEMON.registered_factories[objectId] = obj_or_class
return uri
[docs] def register_factories(self):
"""Register base MCVirt factories with RPC daemon"""
# Register session class
self.register(DaemonSession, objectId='session', force=True)
# Create Virtual machine factory object and register with daemon
virtual_machine_factory = VirtualMachineFactory()
self.register(virtual_machine_factory, objectId='virtual_machine_factory', force=True)
# Create network factory object and register with daemon
network_factory = NetworkFactory()
self.register(network_factory, objectId='network_factory', force=True)
# Create network factory object and register with daemon
hard_drive_factory = HardDriveFactory()
self.register(hard_drive_factory, objectId='hard_drive_factory', force=True)
# Create ISO factory object and register with daemon
iso_factory = IsoFactory()
self.register(iso_factory, objectId='iso_factory', force=True)
# Create auth object and register with daemon
auth = Auth()
self.register(auth, objectId='auth', force=True)
# Create user factory object and register with Daemon
user_factory = UserFactory()
self.register(user_factory, objectId='user_factory', force=True)
# Create cluster object and register with Daemon
cluster = Cluster()
self.register(cluster, objectId='cluster', force=True)
# Create node Drbd object and register with daemon
node_drbd = NodeDrbd()
self.register(node_drbd, objectId='node_drbd', force=True)
# Create network adapter factory and register with daemon
network_adapter_factory = NetworkAdapterFactory()
self.register(network_adapter_factory, objectId='network_adapter_factory', force=True)
# Create node instance and register with daemon
node = Node()
self.register(node, objectId='node', force=True)
# Create logger object and register with daemon
logger = Logger()
self.register(logger, objectId='logger', force=True)
# Create and register SSLSocketFactory object
certificate_generator_factory = CertificateGeneratorFactory()
self.register(certificate_generator_factory,
objectId='certificate_generator_factory', force=True)
# Create libvirt config object and register with daemon
libvirt_config = LibvirtConfig()
self.register(libvirt_config, objectId='libvirt_config', force=True)
# Create and register libvirt connector object
libvirt_connector = LibvirtConnector()
self.register(libvirt_connector, objectId='libvirt_connector', force=True)
# Create an MCVirt session
RpcNSMixinDaemon.DAEMON.registered_factories['mcvirt_session'] = Session()
[docs] def obtain_connection(self):
"""Attempt to obtain a connection to the name server."""
while 1:
try:
Pyro4.naming.locateNS(host=self.hostname, port=9090, broadcast=False)
return
except Exception as e:
Syslogger.logger().warn('Connecting to name server: %s' % str(e))
# Wait for 1 second for name server to come up
time.sleep(1)