# Source code for orbdot.star_planet

"""StarPlanet
==========
This module defines the ``StarPlanet`` class, which combines the data, methods, and attributes
needed to study long-term variations in the orbits of exoplanets.
"""

import numpy as np

import orbdot.tools.utilities as utl
from orbdot.joint_fit import JointFit
from orbdot.nested_sampling import NestedSampling
from orbdot.radial_velocity import RadialVelocity
from orbdot.transit_duration import TransitDuration
from orbdot.transit_timing import TransitTiming


class StarPlanet(TransitTiming, RadialVelocity, TransitDuration, JointFit, NestedSampling):
    """A ``StarPlanet`` class instance represents a star-planet system and acts as an interface
    for the core capabilities of the OrbDot package. It combines the data, methods, and
    attributes needed to run model fitting routines and interpret the results.
    """

    def __init__(self, settings_file, planet_num=0):
        """Initializes the StarPlanet class.

        The settings, system information, and plot settings files are each merged with their
        defaults. A data set (TTV, RV, TDV) is only loaded, and the corresponding parent class
        only initialized, when its ``data_file`` setting is not the string ``"None"``. The
        ``JointFit`` and ``NestedSampling`` parents are always initialized.

        Parameters
        ----------
        settings_file : str
            Path to the main settings file.
        planet_num : int, optional
            Planet number in case of multi-planet systems (default is 0).

        Raises
        ------
        ValueError
            If RV data are loaded and the prior on ``"v0"`` or ``"jit"`` has a shape other than
            ``(3,)`` or ``(number of RV sources, 3)``.

        """
        # define the complete set of allowed parameters in the RV and timing models
        self.legal_params = (
            "t0", "P0", "e0", "i0", "w0", "O0",  # orbital elements
            "ecosw", "esinw", "sq_ecosw", "sq_esinw",  # coupled parameters
            "PdE", "wdE", "edE", "idE", "OdE",  # time-dependent parameters
            "K", "v0", "jit", "dvdt", "ddvdt", "K_tide",  # radial velocity
        )

        # load settings file and merge with defaults
        args = utl.merge_dictionaries("default_settings_file.json", settings_file)

        # load system info file and merge with defaults
        self.sys_info = utl.merge_dictionaries(
            "default_info_file.json", args["system_info_file"]
        )

        # load plot settings file and merge with defaults
        self.plot_settings = utl.merge_dictionaries(
            "default_plot_settings.json", args["plot_settings_file"]
        )

        # define the star and planet names
        self.planet_index = planet_num
        self.star_name = self.sys_info["star_name"]
        self.planet_name = (
            self.sys_info["star_name"] + self.sys_info["planets"][planet_num]
        )

        # define the main directory for saving the results
        self.main_save_dir = args["main_save_dir"] + self.star_name + "/"

        # update the plot settings with the planet name
        self.plot_settings["RV_PLOT"]["title"] = self.planet_name
        self.plot_settings["TTV_PLOT"]["title"] = self.planet_name

        print(f"\nInitializing {self.planet_name} instance...\n")

        # specify default values for model parameters (retrieved from the ``system_info_file``)
        default_values = utl.assign_default_values(self.sys_info, planet_num)

        # print the default parameter values for convenience
        print(f" {self.planet_name} default values: {default_values}\n")

        # initialize the TransitTiming class
        if args["TTV_fit"]["data_file"] != "None":
            # define save directory and load data
            args["TTV_fit"]["save_dir"] = (
                self.main_save_dir + args["TTV_fit"]["save_dir"]
            )
            self.ttv_data_filename = args["TTV_fit"]["data_file"]
            self.ttv_data = utl.read_ttv_data(
                filename=self.ttv_data_filename,
                delim=args["TTV_fit"]["data_delimiter"],
            )

            # initialize class instance
            TransitTiming.__init__(self, args["TTV_fit"])

        # initialize the RadialVelocity class
        if args["RV_fit"]["data_file"] != "None":
            # define save directory and load data
            args["RV_fit"]["save_dir"] = self.main_save_dir + args["RV_fit"]["save_dir"]
            self.rv_data_filename = args["RV_fit"]["data_file"]
            self.rv_data = utl.read_rv_data(
                filename=self.rv_data_filename,
                delim=args["RV_fit"]["data_delimiter"],
            )

            # adjust the priors and default values for multi-instrument RV parameters
            # NOTE(review): a TypeError raised here is deliberately ignored by the original
            # code; the triggering condition is not visible in this file -- confirm intent.
            try:
                num_src = self.rv_data["num_src"]
                for p in ("v0", "jit"):
                    # one default value per RV source
                    default_values[p] = list(np.zeros(num_src))

                    prior_shape = np.shape(args["prior"][p])
                    if prior_shape == (3,):
                        # a single prior is repeated for every RV source
                        args["prior"][p] = [args["prior"][p]] * num_src
                    elif prior_shape != (num_src, 3):
                        raise ValueError(f"Invalid prior for {p} given # of RV sources")

            except TypeError:
                pass

            # initialize class instance
            RadialVelocity.__init__(self, args["RV_fit"])

        # initialize the TransitDuration class
        if args["TDV_fit"]["data_file"] != "None":
            # define save directory and load data
            args["TDV_fit"]["save_dir"] = (
                self.main_save_dir + args["TDV_fit"]["save_dir"]
            )
            self.tdv_data_filename = args["TDV_fit"]["data_file"]
            self.tdv_data = utl.read_tdv_data(
                filename=self.tdv_data_filename,
                delim=args["TDV_fit"]["data_delimiter"],
            )

            # initialize class instance
            TransitDuration.__init__(self, args["TDV_fit"], self.sys_info)

        # initialize the JointFit class
        args["joint_fit"]["save_dir"] = (
            self.main_save_dir + args["joint_fit"]["save_dir"]
        )
        JointFit.__init__(self, args["joint_fit"])

        # initiate the NestedSampling class
        NestedSampling.__init__(self, default_values, args["prior"])

    def _rv_source_index(self, tag):
        """Returns the position of an instrument tag in ``self.rv_data["src_tags"]``.

        Parameters
        ----------
        tag : str
            The instrument tag to look up.

        Returns
        -------
        int or None
            Index of the first matching tag, or ``None`` if the tag is not found.

        """
        matches = np.where(np.array(self.rv_data["src_tags"]) == tag)[0]
        if len(matches) == 0:
            return None
        return int(matches[0])

    def update_default(self, parameter, new_value):
        """Updates the default (fixed) value for the specified parameter.

        The default value will be used in a model fit if the parameter is not allowed to vary.

        Parameters
        ----------
        parameter : str
            The parameter name. The multi-instrument parameters ``"v0"`` and ``"jit"`` must be
            given as ``"v0_tag"`` or ``"jit_tag"``, where ``tag`` is one of the RV source tags.
        new_value : float
            The new parameter value.

        Returns
        -------
        None
            The default value for the specified parameter is updated.

        Raises
        ------
        ValueError
            If the parameter is not one of the allowed model parameters, or if a
            multi-instrument parameter is given without a tag or with an unknown tag.

        """
        multi_types = {"v0": "RV zero velocity", "jit": "RV jitter"}

        # split on the first underscore only, so that instrument tags which
        # themselves contain underscores are matched in full
        base, sep, tag = parameter.partition("_")

        # this is more complex for multi-instrument parameters
        if base in multi_types:
            if not sep:
                raise ValueError(
                    "To update the fixed value for {} please specify the "
                    "instrument by entering the \n parameter as '{}_tag', "
                    "where 'tag' is one of {} for RV source(s) {}".format(
                        multi_types[base],
                        base,
                        self.rv_data["src_tags"],
                        self.rv_data["src_names"],
                    )
                )

            ind = self._rv_source_index(tag)
            if ind is None:
                raise ValueError(
                    "Error in updating fixed value for {}, must be specified "
                    "with the format '{}_tag', \n where 'tag' is one of {}"
                    " for RV source(s) {}.".format(
                        multi_types[base],
                        base,
                        self.rv_data["src_tags"],
                        self.rv_data["src_names"],
                    )
                )

            self.fixed[base][ind] = new_value

            # print updated parameter
            print(
                f"* Default value for '{parameter}' updated to: {self.fixed[base][ind]}\n"
            )
            return

        # check if parameter name is incorrect
        if parameter not in self.legal_params:
            raise ValueError(
                f"'{parameter}' is not a variable in the models, allowed parameters are: "
                f"{self.legal_params}.\n For more information, see ReadMe file or documentation in "
                "the NestedSampling class file."
            )

        self.fixed[parameter] = new_value  # update fixed value

        # print updated parameter
        print(
            f"* Default value for '{parameter}' updated to: {self.fixed[parameter]}\n"
        )

    def update_prior(self, parameter, new_prior):
        """Updates the prior distribution for the specified parameter.

        Parameters
        ----------
        parameter : str
            The parameter name. The multi-instrument parameters ``"v0"`` and ``"jit"`` must be
            given as ``"v0_tag"`` or ``"jit_tag"``, where ``tag`` is one of the RV source tags.
        new_prior : list
            A list of three values specifying the prior distribution, where the first element
            is the type of prior (``"uniform"``, ``"gaussian"``, or ``"log"``), and subsequent
            elements define the distribution.

        Returns
        -------
        None
            The prior distribution for the specified parameter is updated.

        Raises
        ------
        ValueError
            If ``new_prior`` has fewer than three elements, if the parameter is not one of the
            allowed model parameters, or if a multi-instrument parameter is given without a
            tag or with an unknown tag.

        """
        multi_types = {"v0": "RV systemic velocity", "jit": "RV jitter"}

        if len(new_prior) < 3:
            raise ValueError(
                f"The prior on {parameter} cannot be updated to {new_prior}, as it is not in "
                "the correct format.\nThe allowed formats are:\n"
                ' Gaussian -> list : ["gaussian", mean, std]\n'
                ' Log-Uniform -> list : ["log", log10(min), log10(max)]\n'
                ' Uniform -> list : ["uniform", min, max]\n\n'
            )

        # split on the first underscore only, so that instrument tags which
        # themselves contain underscores are matched in full
        base, sep, tag = parameter.partition("_")

        # this is more complex for multi-instrument parameters
        if base in multi_types:
            if not sep:
                raise ValueError(
                    "To update the prior on {} please specify the instrument "
                    "by entering the parameter \n as '{}_tag', where "
                    "'tag' is one of {} for RV source(s) {}".format(
                        multi_types[base],
                        base,
                        self.rv_data["src_tags"],
                        self.rv_data["src_names"],
                    )
                )

            ind = self._rv_source_index(tag)
            if ind is None:
                raise ValueError(
                    "Error in updating prior on {}, must be specified with "
                    "the format '{}_tag', \n where 'tag' is one of {} for "
                    "RV source(s) {}.".format(
                        multi_types[base],
                        base,
                        self.rv_data["src_tags"],
                        self.rv_data["src_names"],
                    )
                )

            self.prior[base][ind] = new_prior

        else:
            # check if parameter name is incorrect
            if parameter not in self.legal_params:
                raise ValueError(
                    f"'{parameter}' is not a variable in any of the models, allowed parameters"
                    f" are:\n{self.legal_params}\n\nSee the OrbDot documentation for more information "
                    "on model parameters."
                )

            self.prior[parameter] = new_prior  # update prior

        # print updated prior
        print(f"* Prior for '{parameter}' updated to: {new_prior} *\n")