Source code for exoctk.atmospheric_retrievals.platon_wrapper

"""A wrapper around the ``platon`` atmospheric retrieval tool.

This module serves as a wrapper around the atmospheric retrieval
software for ``platon``.  It provides methods for performing retrievals
through nested sampling (``multinest``) and MCMC (``emcee``) methods.  For more information
about ``platon``, please see ``https://platon.readthedocs.io``.  For
examples of how to use this software, see the ``examples.py`` module.

Authors
-------

    - Matthew Bourque

Use
---

    Users can perform the atmospheric retrieval by instantiating a
    ``PlatonWrapper`` object and passing fit parameters within the
    python environment.  An example of this is provided below.  For
    more examples of how to use this software, including ways to use
    AWS for performing computations, see the ``examples.py`` module,
    or the ``atmospheric_retrievals_demo.ipynb`` notebook under the
    ``exoctk/notebooks/`` directory.
    ::

        import numpy as np
        from platon.constants import R_sun, R_jup, M_jup
        from exoctk.atmospheric_retrievals.platon_wrapper import PlatonWrapper

        # Build dictionary of parameters you wish to fit
        params = {
            'Rs': 1.19,  # Required
            'Mp': 0.73,  # Required
            'Rp': 1.4,  # Required
            'T': 1200.0,  # Required
            'logZ': 0,  # Optional
            'CO_ratio': 0.53,  # Optional
            'log_cloudtop_P': 4,  # Optional
            'log_scatt_factor': 0,  # Optional
            'scatt_slope': 4,  # Optional
            'error_multiple': 1,  # Optional
            'T_star': 6091}  # Optional

        # Initialize PlatonWrapper object and set the parameters
        pw = PlatonWrapper()
        pw.set_parameters(params)

        # Add any additional fit parameters
        R_guess = 1.4 * R_jup
        T_guess = 1200
        pw.fit_info.add_gaussian_fit_param('Rs', 0.02*R_sun)
        pw.fit_info.add_gaussian_fit_param('Mp', 0.04*M_jup)
        pw.fit_info.add_uniform_fit_param('Rp', 0.9*R_guess, 1.1*R_guess)
        pw.fit_info.add_uniform_fit_param('T', 0.5*T_guess, 1.5*T_guess)
        pw.fit_info.add_uniform_fit_param("log_scatt_factor", 0, 1)
        pw.fit_info.add_uniform_fit_param("logZ", -1, 3)
        pw.fit_info.add_uniform_fit_param("log_cloudtop_P", -0.99, 5)
        pw.fit_info.add_uniform_fit_param("error_multiple", 0.5, 5)

        # Define bins, depths, and errors
        pw.wavelengths = 1e-6*np.array([1.119, 1.138, 1.157, 1.175, 1.194, 1.213, 1.232, 1.251, 1.270, 1.288, 1.307, 1.326, 1.345, 1.364, 1.383, 1.401, 1.420, 1.439, 1.458, 1.477, 1.496, 1.515, 1.533, 1.552, 1.571, 1.590, 1.609, 1.628])
        pw.bins = [[w-0.0095e-6, w+0.0095e-6] for w in pw.wavelengths]
        pw.depths = 1e-6 * np.array([14512.7, 14546.5, 14566.3, 14523.1, 14528.7, 14549.9, 14571.8, 14538.6, 14522.2, 14538.4, 14535.9, 14604.5, 14685.0, 14779.0, 14752.1, 14788.8, 14705.2, 14701.7, 14677.7, 14695.1, 14722.3, 14641.4, 14676.8, 14666.2, 14642.5, 14594.1, 14530.1, 14642.1])
        pw.errors = 1e-6 * np.array([50.6, 35.5, 35.2, 34.6, 34.1, 33.7, 33.5, 33.6, 33.8, 33.7, 33.4, 33.4, 33.5, 33.9, 34.4, 34.5, 34.7, 35.0, 35.4, 35.9, 36.4, 36.6, 37.1, 37.8, 38.6, 39.2, 39.9, 40.8])

        # Perform the retrieval by your favorite method
        pw.retrieve('multinest')  # OR
        pw.retrieve('emcee')

        # Save the results to an output file
        pw.save_results()

        # Save a plot of the results
        pw.make_plot()

Dependencies
------------

    - ``corner``
    - ``exoctk``
    - ``matplotlib``
    - ``platon``
"""

import argparse
import datetime
import getpass
import logging
import os
import pickle
import socket
import sys
import time

import corner
import matplotlib
from platon.retriever import Retriever
from platon.constants import R_sun, R_jup, M_jup

from exoctk.atmospheric_retrievals.aws_tools import build_environment
from exoctk.atmospheric_retrievals.aws_tools import log_output
from exoctk.atmospheric_retrievals.aws_tools import start_ec2
from exoctk.atmospheric_retrievals.aws_tools import stop_ec2
from exoctk.atmospheric_retrievals.aws_tools import transfer_from_ec2
from exoctk.atmospheric_retrievals.aws_tools import transfer_to_ec2


def _apply_factors(params):
    """Convert the size and mass parameters from relative to absolute
    units.

    ``Rs`` is multiplied by ``R_sun``, ``Mp`` by ``M_jup`` and ``Rp`` by
    ``R_jup``.  The conversion is done in place on the given dictionary.

    Parameters
    ----------
    params : dict
        A dictionary of parameters and their values for running the
        software.  See "Use" documentation for further details.

    Returns
    -------
    params : dict
        The same dictionary object that was passed in, with the
        ``Rs``, ``Mp`` and ``Rp`` entries scaled.
    """

    # Each entry pairs a parameter name with the constant it is scaled by
    scalings = (('Rs', R_sun), ('Mp', M_jup), ('Rp', R_jup))

    for key, factor in scalings:
        params[key] = params[key] * factor

    return params


def _log_execution_time(start_time):
    """Write the elapsed retrieval time to the log.

    The elapsed time is measured from ``start_time`` to the moment this
    function is called, and is logged as ``hours:minutes:seconds``
    (whole numbers, not zero-padded).

    Parameters
    ----------
    start_time : float
        The start time of the retrieval execution, as returned by
        ``time.time()``.
    """

    elapsed = time.time() - start_time

    # Break the elapsed seconds down into hours, minutes and seconds
    hours, leftover = divmod(elapsed, 60 * 60)
    minutes, seconds = divmod(leftover, 60)

    message = 'Retrieval Execution Time: {}:{}:{}'.format(int(hours), int(minutes), int(seconds))
    logging.info(message)


def _parse_args():
    """Parses and returns command line arguments.

    The only argument is the positional ``method``, which must be one
    of the retrieval methods supported by ``PlatonWrapper.retrieve``.
    Any other value makes ``argparse`` print a usage message and exit.

    Returns
    -------
    args : obj
        An object containing the command line argument values.  The
        retrieval method is available as ``args.method``.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument(
        'method',
        type=str,
        # Reject unsupported methods here rather than after the
        # (potentially large) pickled object has been loaded
        choices=['emcee', 'multinest'],
        help='Retrieval method (either "emcee" or "multinest")')
    args = parser.parse_args()

    return args


def _validate_parameters(supplied_params):
    """Ensure the supplied parameters are valid.  Throw assertion
    errors if they are not.

    Every required parameter must be present, and every supplied
    parameter that is known to this function must be of the expected
    data type.  Parameters that are not listed below are ignored.

    Parameters
    ----------
    supplied_params : dict
        A dictionary of parameters and their values for running the
        software.  See "Use" documentation for further details.

    Raises
    ------
    AssertionError
        If a required parameter is missing, or if a supplied parameter
        is not of the expected data type.  The error is raised
        explicitly, so the checks remain active when python is run
        with ``-O``.
    """

    # Define the parameters, their data types, and if they are required
    parameters = [('Rs', float, True),
                  ('Mp', float, True),
                  ('Rp', float, True),
                  ('T', float, True),
                  ('logZ', int, False),
                  ('CO_ratio', float, False),
                  ('log_cloudtop_P', int, False),
                  ('log_scatt_factor', int, False),
                  ('scatt_slope', int, False),
                  ('error_multiple', int, False),
                  ('T_star', int, False)]

    for name, data_type, required in parameters:

        # Ensure that all required parameters are supplied
        if name not in supplied_params:
            if required:
                raise AssertionError('{} missing from parameters'.format(name))
            continue

        # Ensure the supplied parameter is of a valid data type.
        # Subclasses of the expected type are accepted, but ``bool`` is
        # excluded explicitly because it is a subclass of ``int``.
        value = supplied_params[name]
        if isinstance(value, bool) or not isinstance(value, data_type):
            raise AssertionError('{} is not of type {}'.format(name, data_type.__name__))


class PlatonWrapper():
    """Class object for running the platon atmospheric retrieval software.

    Typical use is to call ``set_parameters``, add fit parameters to
    ``fit_info``, assign ``bins``, ``depths`` and ``errors``, optionally
    call ``use_aws``, and then call ``retrieve``.  After a local
    retrieval, ``save_results`` and ``make_plot`` write the output
    products to the current working directory.
    """

    def __init__(self):
        """Initialize the class object."""

        self.ec2_id = ''
        self.output_results = 'results.dat'
        self.output_plot = 'corner.png'
        self.retriever = Retriever()
        self.ssh_file = ''
        self.aws = False

        # State that is filled in later by the user or by other methods.
        # Declared here so that a missing step is reported clearly
        # instead of surfacing as an AttributeError.
        self.params = None  # set by set_parameters()
        self.fit_info = None  # set by set_parameters()
        self.bins = None  # set by the user
        self.depths = None  # set by the user
        self.errors = None  # set by the user
        self.method = None  # set by retrieve()
        self.result = None  # set by retrieve() when run locally
        self.build_required = False  # set by use_aws()

        self._configure_logging()

    def _configure_logging(self):
        """Creates a log file that logs the execution of the script.

        Log files are written to a ``logs/`` subdirectory within the
        current working directory.  Any handlers already attached to
        the root logger are removed first.  The time at which logging
        was configured is stored in ``self.start_time``; nothing is
        returned.
        """

        # Define save location
        log_file = 'logs/{}.log'.format(datetime.datetime.now().strftime('%Y-%m-%d-%H-%M'))

        # Create the subdirectory if necessary.  ``exist_ok`` avoids a
        # failure if the directory appears between a check and the call.
        os.makedirs('logs/', exist_ok=True)

        # Make sure no other root handlers exist before configuring the logger
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)

        # Create the log file
        logging.basicConfig(filename=log_file,
                            format='%(asctime)s %(levelname)s: %(message)s',
                            datefmt='%m/%d/%Y %H:%M:%S %p',
                            level=logging.INFO)
        print('Log file initialized to {}'.format(log_file))

        # Log environment information
        logging.info('User: ' + getpass.getuser())
        logging.info('System: ' + socket.gethostname())
        logging.info('Python Version: ' + sys.version.replace('\n', ''))
        logging.info('Python Executable Path: ' + sys.executable)

        self.start_time = time.time()

    def _require_result(self, action):
        """Raise an error if there are no in-memory retrieval results.

        ``getattr`` is used so that the check also works on objects
        that never had ``method`` or ``result`` assigned (for example
        objects unpickled from a file written before these attributes
        were initialized).

        Parameters
        ----------
        action : str
            A short description of what the caller wants to do, used
            in the error message.

        Raises
        ------
        RuntimeError
            If ``method`` or ``result`` is unset.
        """

        method = getattr(self, 'method', None)
        result = getattr(self, 'result', None)

        if method is None or result is None:
            raise RuntimeError(
                'Cannot {}: no retrieval results are available.  Run retrieve() '
                'locally first; the products of an AWS retrieval are transferred '
                'as files and are not loaded into this object.'.format(action))

    def make_plot(self):
        """Create a corner plot that shows the results of the retrieval.

        The plot is written to ``<method>_corner.png`` in the current
        working directory and its name is stored in ``output_plot``.

        Raises
        ------
        RuntimeError
            If no in-memory retrieval results are available.
        ValueError
            If ``method`` is neither ``emcee`` nor ``multinest``.
        """

        self._require_result('create corner plot')

        print('Creating corner plot')
        logging.info('Creating corner plot')

        matplotlib.rcParams['text.usetex'] = False

        if self.method == 'emcee':
            fig = corner.corner(self.result.flatchain,
                                range=[0.99] * self.result.flatchain.shape[1],
                                labels=self.fit_info.fit_param_names)

        elif self.method == 'multinest':
            fig = corner.corner(self.result.samples,
                                weights=self.result.weights,
                                range=[0.99] * self.result.samples.shape[1],
                                labels=self.fit_info.fit_param_names)

        else:
            raise ValueError('Unrecognized method: {}'.format(self.method))

        # Save the results
        self.output_plot = '{}_corner.png'.format(self.method)
        fig.savefig(self.output_plot)
        print('Corner plot saved to {}'.format(self.output_plot))
        logging.info('Corner plot saved to {}'.format(self.output_plot))

    def retrieve(self, method):
        """Perform the atmospheric retrieval via the given method

        When ``use_aws`` has been called, the retrieval is run on an
        EC2 instance and the output products are transferred back as
        files.  Otherwise the retrieval is run locally and the outcome
        is stored in ``result``.

        Parameters
        ----------
        method : str
            The method by which to perform atmospheric retrievals.  Can
            either be ``emcee`` or ``multinest``.

        Raises
        ------
        AssertionError
            If ``method`` is not a recognized retrieval method.
        """

        print('Performing atmopsheric retrievals via {}'.format(method))
        logging.info('Performing atmopsheric retrievals via {}'.format(method))

        # Ensure that the method parameter is valid.  Raised explicitly
        # so that the check is not stripped when python is run with -O.
        if method not in ['multinest', 'emcee']:
            raise AssertionError('Unrecognized method: {}'.format(method))
        self.method = method

        # For processing on AWS
        if self.aws:

            # Start or create an EC2 instance
            instance, key, client = start_ec2(self.ssh_file, self.ec2_id)

            try:
                # Build the environment on EC2 instance if necessary
                if self.build_required:
                    build_environment(instance, key, client)

                # Transfer object file to EC2
                transfer_to_ec2(instance, key, client, 'pw.obj')

                # Connect to the EC2 instance and run commands
                command = './exoctk/exoctk/atmospheric_retrievals/exoctk-env-init.sh python exoctk/exoctk/atmospheric_retrievals/platon_wrapper.py {}'.format(self.method)
                client.connect(hostname=instance.public_dns_name, username='ec2-user', pkey=key)
                stdin, stdout, stderr = client.exec_command(command)
                output = stdout.read()
                errors = stderr.read()
                log_output(output)
                log_output(errors)

                # Transfer output products from EC2 to user
                if self.method == 'emcee':
                    transfer_from_ec2(instance, key, client, 'emcee_results.obj')
                    transfer_from_ec2(instance, key, client, 'emcee_corner.png')
                elif self.method == 'multinest':
                    transfer_from_ec2(instance, key, client, 'multinest_results.dat')
                    transfer_from_ec2(instance, key, client, 'multinest_corner.png')

            finally:
                # Terminate or stop the EC2 instance, even if one of the
                # steps above failed, so it is not left running
                stop_ec2(self.ec2_id, instance)

        # For processing locally
        else:
            if self.method == 'emcee':
                self.result = self.retriever.run_emcee(self.bins, self.depths, self.errors, self.fit_info)
            elif self.method == 'multinest':
                self.result = self.retriever.run_multinest(self.bins, self.depths, self.errors, self.fit_info, plot_best=False)

        _log_execution_time(self.start_time)

    def save_results(self):
        """Save the results of the retrieval to an output file.

        ``multinest`` results are written as text to
        ``multinest_results.dat``; ``emcee`` results are pickled to
        ``emcee_results.obj``.  The file name is stored in
        ``output_results``.

        Raises
        ------
        RuntimeError
            If no in-memory retrieval results are available.  The check
            is made before any file is opened, so an existing results
            file (e.g. one transferred from EC2) is not truncated.
        ValueError
            If ``method`` is neither ``emcee`` nor ``multinest``.
        """

        self._require_result('save results')

        print('Saving results')
        logging.info('Saving results')

        # Save the results
        if self.method == 'multinest':
            self.output_results = 'multinest_results.dat'
            with open(self.output_results, 'w') as f:
                f.write(str(self.result))
        elif self.method == 'emcee':
            self.output_results = 'emcee_results.obj'
            with open(self.output_results, 'wb') as f:
                pickle.dump(self.result, f)
        else:
            raise ValueError('Unrecognized method: {}'.format(self.method))

        print('Results file saved to {}'.format(self.output_results))
        logging.info('Results file saved to {}'.format(self.output_results))

    def set_parameters(self, params):
        """Set necessary parameters to perform the retrieval.

        Required parameters include ``Rs``, ``Mp``, ``Rp``, and ``T``.
        Optional parameters include ``logZ``, ``CO_ratio``,
        ``log_cloudtop_P``, ``log_scatt_factor``, ``scatt_slope``,
        ``error_multiple``, and ``T_star``.

        The converted parameters are stored in ``params`` and the
        default fit information built from them in ``fit_info``.

        Parameters
        ----------
        params : dict
            A dictionary of parameters and their values for running the
            software.  See "Use" documentation for further details.
            The dictionary itself is left unmodified.

        Raises
        ------
        AssertionError
            If a required parameter is missing or a parameter is of the
            wrong data type.
        """

        print('Setting parameters: {}'.format(params))
        logging.info('Setting parameters: {}'.format(params))

        _validate_parameters(params)

        # Convert a copy: the unit factors are applied in place, so
        # converting the caller's dictionary directly would apply them
        # a second time if the same dictionary were passed in again
        self.params = _apply_factors(dict(params))
        self.fit_info = self.retriever.get_default_fit_info(**self.params)

    def use_aws(self, ssh_file, ec2_id):
        """Sets appropriate parameters in order to perform processing
        using an AWS EC2 instance.

        The object is pickled to ``pw.obj`` in its current state, so
        parameters, fit information and data should be set before
        calling this method.

        Parameters
        ----------
        ssh_file : str
            The path to a public SSH key used to connect to the EC2
            instance.
        ec2_id : str
            A template id that points to a pre-built EC2 instance.
        """

        print('Using AWS for processing')
        logging.info('Using AWS for processing')

        self.ssh_file = ssh_file
        self.ec2_id = ec2_id

        # If the ec2_id is a template ID, then building the instance is required
        self.build_required = ec2_id.split('-')[0] == 'lt'

        # Write out object to file.  This is done while ``aws`` is still
        # False: the pickled copy is the one loaded and run by this
        # module's script entry point on the instance, where the
        # retrieval has to take the local branch.
        with open('pw.obj', 'wb') as f:
            pickle.dump(self, f)
        print('Saved PlatonWrapper object to pw.obj')
        logging.info('Saved PlatonWrapper object to pw.obj')

        self.aws = True
if __name__ == '__main__':

    # Command line entry point: run the retrieval described by the
    # pickled PlatonWrapper object found in the current directory

    # Determine which retrieval method was requested
    args = _parse_args()

    # Restore the PlatonWrapper object from disk.
    # NOTE(review): unpickling executes arbitrary code, so pw.obj must
    # only ever come from a trusted source.
    with open('pw.obj', 'rb') as obj_file:
        pw = pickle.load(obj_file)

    # Run the retrieval with the requested method
    pw.retrieve(args.method)

    # Write the results to an output file
    pw.save_results()

    # Write a corner plot of the results
    pw.make_plot()