src.cross_validate

Module Contents

Classes

CrossValidation

class src.cross_validate.CrossValidation(saving_dir: str, n_splits: int = 5)
device
fit(model: src.solutions.base_solution.BaseSolution, X: pandas.DataFrame, y: pandas.DataFrame) → pandas.DataFrame

Trains the model on each of the n_splits cross-validation folds and scores it on the held-out part of each fold.

Parameters
  • model – predictor that is an instance of BaseSolution (or a subclass)

  • X – DataFrame with text_id and full_text columns

  • y – DataFrame with text_id, cohesion, … columns

Returns

DataFrame with per-class scores for each split and the overall CV score
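As a rough illustration of what fit reports, the sketch below runs a shuffled K-fold loop with pandas and NumPy. It is not the module's implementation. Assumptions: each fold trains a deep copy of the passed model, X and y list the same text_id values in the same row order, each class score is the RMSE on the held-out fold, and the overall CV score is the mean over splits (with a "mean" column averaging over classes). `cross_validate_fit`, `kfold_indices`, `MeanPredictor` and the `target_b` column are hypothetical names. The sketch also returns the fold models so that they can be averaged at prediction time.

```python
import copy

import numpy as np
import pandas as pd


class MeanPredictor:
    """Hypothetical stand-in for a BaseSolution subclass: predicts training means."""

    def fit(self, X: pd.DataFrame, y: pd.DataFrame) -> "MeanPredictor":
        # Store the mean of every target column (all columns except text_id)
        self.means_ = y.drop(columns="text_id").mean()
        return self

    def predict(self, X: pd.DataFrame) -> pd.DataFrame:
        # Same score for every row, one column per class, keyed by text_id
        preds = pd.DataFrame(dict(self.means_), index=X.index)
        preds.insert(0, "text_id", X["text_id"].to_numpy())
        return preds


def kfold_indices(n_rows: int, n_splits: int, seed: int = 0):
    """Yield (train_idx, valid_idx) position arrays for a shuffled K-fold split."""
    order = np.random.default_rng(seed).permutation(n_rows)
    for valid_idx in np.array_split(order, n_splits):
        yield np.setdiff1d(order, valid_idx), valid_idx


def cross_validate_fit(model, X: pd.DataFrame, y: pd.DataFrame, n_splits: int = 5):
    """Fit a copy of `model` per fold; return per-class RMSE per split and fold models."""
    targets = [c for c in y.columns if c != "text_id"]
    rows, fold_models = [], []
    for train_idx, valid_idx in kfold_indices(len(X), n_splits):
        fold_model = copy.deepcopy(model).fit(X.iloc[train_idx], y.iloc[train_idx])
        pred = fold_model.predict(X.iloc[valid_idx])
        # Align predictions with ground truth on text_id before scoring
        merged = y.iloc[valid_idx].merge(pred, on="text_id", suffixes=("_true", "_pred"))
        rows.append({
            c: float(np.sqrt(np.mean((merged[f"{c}_pred"] - merged[f"{c}_true"]) ** 2)))
            for c in targets
        })
        fold_models.append(fold_model)
    scores = pd.DataFrame(rows, index=[f"split_{i}" for i in range(n_splits)])
    scores["mean"] = scores[targets].mean(axis=1)  # average over classes per split
    overall = scores[targets].mean()               # average over splits per class
    overall["mean"] = overall.mean()
    scores.loc["cv"] = overall                     # final row: overall CV score
    return scores, fold_models


# Usage on toy data (cohesion plus one hypothetical target column)
X = pd.DataFrame({"text_id": [f"t{i}" for i in range(10)],
                  "full_text": [f"essay {i}" for i in range(10)]})
y = pd.DataFrame({"text_id": X["text_id"], "cohesion": 3.0,
                  "target_b": np.arange(10, dtype=float)})
scores, fold_models = cross_validate_fit(MeanPredictor(), X, y, n_splits=5)
print(scores)
```

With ten texts and n_splits=5, the returned frame has one row per split plus a final "cv" row. Deep-copying the model keeps each fold's fitted state independent of the others.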

predict(X: pandas.DataFrame) → pandas.DataFrame

Averages the predictions of the models trained on each fold.

Parameters

X – DataFrame with text_id and full_text columns

Returns

Prediction DataFrame with text_id, cohesion, … columns
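One way to read "average fold prediction" is sketched below: every fold model predicts on the same X, and the per-class scores are averaged for each text_id. This is a sketch under assumptions, not the module's code. It assumes each fold model's predict returns a DataFrame with text_id plus one numeric column per class, and that the average is unweighted. `average_fold_predictions`, `predict_with_folds`, `ConstantPredictor` and the `target_b` column are hypothetical.

```python
from typing import Sequence

import numpy as np
import pandas as pd


def average_fold_predictions(fold_predictions: Sequence[pd.DataFrame]) -> pd.DataFrame:
    """Unweighted mean of per-fold prediction frames, matched on text_id."""
    stacked = pd.concat(fold_predictions, ignore_index=True)
    # Rows for the same text_id come from different folds; average them per column
    averaged = stacked.groupby("text_id", sort=False).mean(numeric_only=True)
    return averaged.reset_index()


class ConstantPredictor:
    """Hypothetical fold model that predicts the same score for every class."""

    def __init__(self, value: float):
        self.value = value

    def predict(self, X: pd.DataFrame) -> pd.DataFrame:
        return pd.DataFrame({"text_id": X["text_id"].to_numpy(),
                             "cohesion": self.value, "target_b": self.value})


def predict_with_folds(fold_models, X: pd.DataFrame) -> pd.DataFrame:
    """Ask every fold model for predictions on X and average them."""
    return average_fold_predictions([m.predict(X) for m in fold_models])


X = pd.DataFrame({"text_id": ["a", "b"], "full_text": ["first essay", "second essay"]})
preds = predict_with_folds([ConstantPredictor(2.0), ConstantPredictor(4.0)], X)
print(preds)  # cohesion and target_b are 3.0 for both texts
```

Grouping on text_id, rather than adding the frames positionally, keeps the average correct even if a fold model returns its rows in a different order.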

save(path: Union[str, pathlib.Path])
load(path: Union[str, pathlib.Path], predictor: src.solutions.base_solution.BaseSolution)