pykoop.score_trajectory
- score_trajectory(X_predicted, X_expected, n_steps=None, discount_factor=1, regression_metric='neg_mean_squared_error', regression_metric_kw=None, error_score=nan, min_samples=1, episode_feature=False)
Score a predicted data matrix against an expected data matrix.
- Parameters:
X_predicted (np.ndarray) – Predicted state data matrix.
X_expected (np.ndarray) – Expected state data matrix.
n_steps (Optional[int]) – Number of steps ahead to predict. If None or longer than the episode, the entire episode is scored.
discount_factor (float) – Discount factor used to weight the error timeseries. Should be positive, with magnitude 1 or slightly less. The error at each timestep is weighted by discount_factor**k, where k is the timestep.
regression_metric (Union[str, Callable]) – Regression metric to use. One of 'explained_variance', 'neg_mean_absolute_error', 'neg_mean_squared_error', 'neg_mean_squared_log_error', 'neg_median_absolute_error', 'r2', or 'neg_mean_absolute_percentage_error', which are existing scikit-learn regression metrics [1]. Can also be a callable that takes the same keyword arguments as the scikit-learn metrics, that is, at least y_true, y_pred, sample_weight, and multioutput.
regression_metric_kw (Optional[Dict[str, Any]]) – Keyword arguments for regression_metric. If the sample_weight keyword argument is specified, discount_factor is ignored.
error_score (Union[str, float]) – Value to assign to the score if X_predicted has diverged or if an error has occurred in estimator fitting. If set to 'raise', a ValueError is raised. If a numerical value is given, a sklearn.exceptions.FitFailedWarning warning is raised and the specified score is returned. The error score defines the worst possible score: if a score is finite but lower than the error score, the error score is returned instead.
min_samples (int) – Number of samples in initial condition.
episode_feature (bool) – True if first feature indicates which episode a timestep is from.
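The weighting described by n_steps, discount_factor, min_samples, and episode_feature can be illustrated with a minimal NumPy sketch. It is not pykoop's implementation: the names neg_mse and discounted_score are hypothetical, and so are two choices made here, namely that the first min_samples rows of each episode (the initial condition) are dropped before scoring, and that with episode_feature=True column 0 holds the episode index and the timestep k restarts at 0 in each episode. neg_mse stands in for a callable regression_metric with the scikit-learn keyword arguments (y_true, y_pred, sample_weight, multioutput), assumed here to return a greater-is-better score.

```python
import numpy as np


def neg_mse(y_true, y_pred, sample_weight=None, multioutput="uniform_average"):
    """Negative MSE with scikit-learn-style keyword arguments (greater is better)."""
    # Weighted mean over samples (rows) for each output column.
    errors = np.average((y_true - y_pred) ** 2, axis=0, weights=sample_weight)
    if multioutput == "raw_values":
        return -errors
    # 'uniform_average': every output column counts equally.
    return -float(np.mean(errors))


def discounted_score(X_predicted, X_expected, n_steps=None, discount_factor=1.0,
                     metric=neg_mse, min_samples=1, episode_feature=False):
    """Hypothetical sketch of discounted, truncated trajectory scoring."""
    X_predicted = np.asarray(X_predicted, dtype=float)
    X_expected = np.asarray(X_expected, dtype=float)
    if episode_feature:
        # Assumption: column 0 holds the episode index; states follow it.
        episodes = np.unique(X_expected[:, 0])
        pairs = [(X_predicted[X_predicted[:, 0] == ep, 1:],
                  X_expected[X_expected[:, 0] == ep, 1:]) for ep in episodes]
    else:
        pairs = [(X_predicted, X_expected)]
    y_pred, y_true, weights = [], [], []
    for X_p, X_e in pairs:
        # Assumption: the first min_samples rows are the initial condition
        # and are excluded from the score.
        X_p, X_e = X_p[min_samples:], X_e[min_samples:]
        # None, or a horizon longer than the episode, keeps the whole episode.
        n = X_e.shape[0] if n_steps is None else min(n_steps, X_e.shape[0])
        y_pred.append(X_p[:n])
        y_true.append(X_e[:n])
        # The error at timestep k (restarting at 0 in each episode) is
        # weighted by discount_factor**k.
        weights.append(discount_factor ** np.arange(n, dtype=float))
    return metric(y_true=np.vstack(y_true), y_pred=np.vstack(y_pred),
                  sample_weight=np.concatenate(weights),
                  multioutput="uniform_average")
```

With discount_factor below 1 the weights decay geometrically, so in this sketch errors late in a long prediction horizon contribute less to the score than early ones.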
- Returns:
Score (greater is better).
- Return type:
float
- Raises:
ValueError – If error_score='raise' and an error occurs in scoring.
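The error_score rules can be sketched as a small post-processing step applied to a computed score. The helper apply_error_score is hypothetical; it treats a None or non-finite score as a failure (an assumption standing in for pykoop's own divergence check), and it emits a plain UserWarning where pykoop documents sklearn.exceptions.FitFailedWarning.

```python
import math
import warnings


def apply_error_score(score, error_score=math.nan):
    """Hypothetical sketch of the documented error_score rules.

    A score that is None or non-finite (e.g. from a diverged prediction)
    is treated as a scoring failure; this detection rule is an assumption.
    """
    if score is None or not math.isfinite(score):
        if error_score == "raise":
            raise ValueError("Prediction diverged or scoring failed.")
        # pykoop documents a sklearn.exceptions.FitFailedWarning here; a plain
        # UserWarning keeps this sketch free of the scikit-learn dependency.
        warnings.warn("Scoring failed; returning error_score instead.")
        return float(error_score)
    # The error score is the worst possible score, so a finite score below
    # it is replaced by it. Comparisons with NaN are False, so the default
    # error_score=nan never clips a finite score.
    if error_score != "raise" and score < error_score:
        return float(error_score)
    return float(score)
```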
References