src.cross_validate
Module Contents
Classes
- class src.cross_validate.CrossValidation(saving_dir: str, n_splits: int = 5)
- device
- fit(model: src.solutions.base_solution.BaseSolution, X: pandas.DataFrame, y: pandas.DataFrame) → pandas.DataFrame
Trains the model on each of the n_splits cross-validation splits and scores it on the corresponding held-out fold
- Parameters
model – predictor implementing the BaseSolution interface
X – DataFrame that has text_id and full_text columns
y – DataFrame that has text_id, cohesion, … columns
- Returns
DataFrame with per-class scores for each split and the overall CV score
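A minimal sketch of the k-fold loop that `fit` appears to perform, under stated assumptions: folds are contiguous and unshuffled, the per-class score is RMSE, and the overall CV score is the mean of the per-split column-averaged RMSE. None of these choices is confirmed by the reference above. `kfold_indices`, `MeanPredictor`, and `cross_validate` are hypothetical names, and plain Python lists stand in for the DataFrames so the example has no third-party dependencies.

```python
import math
from statistics import mean
from typing import Callable, Dict, List, Sequence, Tuple


def kfold_indices(n_samples: int, n_splits: int = 5) -> List[Tuple[List[int], List[int]]]:
    """Split range(n_samples) into n_splits contiguous folds (no shuffling).

    Returns (train_indices, valid_indices) pairs. The first n_samples % n_splits
    folds receive one extra sample, so fold sizes differ by at most one.
    """
    if not 2 <= n_splits <= n_samples:
        raise ValueError("n_splits must be between 2 and n_samples")
    base, extra = divmod(n_samples, n_splits)
    splits, start = [], 0
    for fold in range(n_splits):
        stop = start + base + (1 if fold < extra else 0)
        valid = list(range(start, stop))
        train = list(range(0, start)) + list(range(stop, n_samples))
        splits.append((train, valid))
        start = stop
    return splits


def rmse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Root mean squared error between two equal-length sequences."""
    return math.sqrt(mean((t - p) ** 2 for t, p in zip(y_true, y_pred)))


class MeanPredictor:
    """Hypothetical stand-in for a BaseSolution: predicts each target's training mean."""

    def fit(self, X: List[str], y: List[Dict[str, float]]) -> None:
        self.means = {col: mean(row[col] for row in y) for col in y[0]}

    def predict(self, X: List[str]) -> List[Dict[str, float]]:
        return [dict(self.means) for _ in X]


def cross_validate(
    make_model: Callable[[], MeanPredictor],
    X: List[str],
    y: List[Dict[str, float]],
    n_splits: int = 5,
) -> Tuple[List[Dict[str, float]], float]:
    """Return per-split, per-target RMSE and the overall CV score (mean over splits)."""
    split_scores = []
    for train_idx, valid_idx in kfold_indices(len(X), n_splits):
        model = make_model()  # a fresh, untrained model for every fold
        model.fit([X[i] for i in train_idx], [y[i] for i in train_idx])
        preds = model.predict([X[i] for i in valid_idx])
        # Score each target column on the held-out fold only.
        scores = {
            col: rmse([y[i][col] for i in valid_idx], [p[col] for p in preds])
            for col in y[0]
        }
        split_scores.append(scores)
    overall = mean(mean(s.values()) for s in split_scores)
    return split_scores, overall
```

For example, `cross_validate(MeanPredictor, ["a", "b", "c", "d"], [{"cohesion": v} for v in (1.0, 2.0, 3.0, 4.0)], n_splits=2)` trains on two essays per fold and scores on the other two. Creating a fresh model per fold keeps each held-out score independent of the data it is evaluated on.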
- predict(X: pandas.DataFrame) → pandas.DataFrame
Predicts with the model fitted on each fold and averages the fold predictions
- Parameters
X – DataFrame that has text_id and full_text columns
- Returns
Prediction DataFrame that has text_id, cohesion, … columns
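A minimal sketch of averaging fold predictions element-wise, keyed by text_id. It assumes every fold model predicts the same text_ids and target columns; `average_fold_predictions` and the nested-dict `Prediction` layout are hypothetical stand-ins for the per-fold DataFrames, chosen so the example runs with the standard library alone.

```python
from statistics import mean
from typing import Dict, List

# text_id -> {target column: predicted score}
Prediction = Dict[str, Dict[str, float]]


def average_fold_predictions(fold_preds: List[Prediction]) -> Prediction:
    """Average per-fold predictions for each text_id and target column.

    Assumes all folds share the same text_ids and target columns.
    """
    if not fold_preds:
        raise ValueError("need at least one fold prediction")
    averaged: Prediction = {}
    for text_id, first in fold_preds[0].items():
        # Mean of this text's score across all fold models, per target column.
        averaged[text_id] = {
            col: mean(fold[text_id][col] for fold in fold_preds) for col in first
        }
    return averaged
```

With pandas, assuming each fold's frame holds only text_id and numeric target columns, the same step can be written as `pd.concat(fold_frames).groupby("text_id", as_index=False).mean()`.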
- save(path: Union[str, pathlib.Path])
- load(path: Union[str, pathlib.Path], predictor: src.solutions.base_solution.BaseSolution)