# Source code for mcvirt.rpc.ssl_socket

"""Provides helpers for wrapping Pyro sockets with SSL"""
# Copyright (c) 2016 - I.T. Dev Ltd
#
# This file is part of MCVirt.
#
# MCVirt is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# MCVirt is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with MCVirt.  If not, see <http://www.gnu.org/licenses/>

from Pyro4 import socketutil
import ssl
import socket

from mcvirt.rpc.certificate_generator_factory import CertificateGeneratorFactory


class SSLSocket(object):
    """Provides helpers for wrapping Pyro sockets with SSL"""

    @staticmethod
    def wrap_socket(socket_object, *args, **kwargs):
        """Wrap a Pyro socket connection with SSL.

        The socket is treated as server-side when a ``bind`` keyword argument
        was supplied, and as client-side otherwise.

        Server side: the certificate generator for 'localhost' has its
        certificates checked (client certificates excluded), and its server
        key/public certificate are presented to peers.

        Client side: the peer certificate is required and verified against
        the ``ca_pub_file`` of the certificate generator for the host being
        connected to. If the ``connect`` host parses as an IPv4 address it is
        reverse-resolved to a hostname first; if it does not parse, or the
        reverse lookup fails, the value is used as-is.

        :param socket_object: the plain socket to wrap
        :param args: ignored; accepted so the signature mirrors the Pyro
                     socket factory whose arguments are forwarded here
        :param kwargs: keyword arguments given to the Pyro socket factory;
                       ``bind`` selects server mode, and in client mode
                       ``connect`` (whose first item is the host) is required
        :returns: the socket returned by :func:`ssl.wrap_socket`
        :raises KeyError: in client mode when no ``connect`` argument exists
        """
        server_side = 'bind' in kwargs
        ssl_kwargs = {
            'do_handshake_on_connect': True,
            # NOTE(review): TLS 1.0 is deprecated. It is kept because both
            # ends must negotiate the same protocol; upgrade all nodes
            # together. ssl.wrap_socket itself is also removed in Python 3.12
            # (SSLContext.wrap_socket is the replacement).
            'ssl_version': ssl.PROTOCOL_TLSv1,
            'server_side': server_side
        }
        cert_gen_factory = CertificateGeneratorFactory()
        if server_side:
            cert_gen = cert_gen_factory.get_cert_generator(server='localhost')
            cert_gen.check_certificates(check_client=False)
            ssl_kwargs['keyfile'] = cert_gen.server_key_file
            ssl_kwargs['certfile'] = cert_gen.server_pub_file
        else:
            host = kwargs['connect'][0]
            # Determine if the host is an IP address; if so, look up its
            # hostname so the matching certificate generator is selected.
            try:
                socket.inet_aton(host)
                hostname = socket.gethostbyaddr(host)[0]
            except socket.error:
                hostname = host
            cert_gen = cert_gen_factory.get_cert_generator(hostname)
            # Trust is anchored on the per-host CA file; ssl.wrap_socket
            # performs no hostname matching of its own.
            ssl_kwargs['cert_reqs'] = ssl.CERT_REQUIRED
            ssl_kwargs['ca_certs'] = cert_gen.ca_pub_file
        return ssl.wrap_socket(socket_object, **ssl_kwargs)

    @staticmethod
    def _wrap_new_socket(socket_factory, *args, **kwargs):
        """Create a socket with ``socket_factory`` and wrap it with SSL.

        ``args``/``kwargs`` are forwarded both to ``socket_factory`` and to
        :meth:`wrap_socket`. If wrapping raises (certificate setup error,
        failed handshake, ...) the freshly created plain socket is closed
        explicitly before the exception propagates, instead of being left
        open until garbage collection.
        """
        plain_socket = socket_factory(*args, **kwargs)
        try:
            return SSLSocket.wrap_socket(plain_socket, *args, **kwargs)
        except BaseException:
            plain_socket.close()
            raise

    @staticmethod
    def create_ssl_socket(*args, **kwargs):
        """Override the Pyro createSocket method and wrap with SSL.

        All arguments are forwarded unchanged to ``socketutil.createSocket``
        and to :meth:`wrap_socket`.
        """
        return SSLSocket._wrap_new_socket(socketutil.createSocket,
                                          *args, **kwargs)

    @staticmethod
    def create_broadcast_ssl_socket(*args, **kwargs):
        """Override the Pyro createBroadcastSocket method and wrap with SSL.

        All arguments are forwarded unchanged to
        ``socketutil.createBroadcastSocket`` and to :meth:`wrap_socket`.

        NOTE(review): the ssl module only supports SOCK_STREAM sockets;
        confirm the broadcast socket created here is a stream socket,
        otherwise wrapping will raise (and the socket is then closed).
        """
        return SSLSocket._wrap_new_socket(socketutil.createBroadcastSocket,
                                          *args, **kwargs)