debeir.evaluation.cross_validation

  1from enum import Enum
  2from typing import Dict, List, Union
  3
  4import numpy as np
  5from debeir.datasets.types import DatasetTypes, InputExample
  6from sklearn.model_selection import KFold, StratifiedKFold
  7
  8import datasets
  9
 10
 11def split_k_fold(n_fold, data_files):
 12    percentage = 100 // n_fold
 13
 14    vals_ds = datasets.load_dataset('csv', split=[
 15        f'train[{k}%:{k + percentage}%]' for k in range(0, 100, percentage)
 16    ], data_files=data_files)
 17
 18    trains_ds = datasets.load_dataset('csv', split=[
 19        f'train[:{k}%]+train[{k + percentage}%:]' for k in range(0, 100, percentage)
 20    ], data_files=data_files)
 21
 22    return trains_ds, vals_ds
 23
 24
 25class CrossValidatorTypes(Enum):
 26    """
 27    Cross Validator Strategies for separating the dataset
 28    """
 29    Stratified = "StratifiedKFold"
 30    KFold = "KFold"
 31
 32
 33str_to_fn = {
 34    "StratifiedKFold": StratifiedKFold,
 35    "KFold": KFold
 36}
 37
 38
 39class CrossValidator:
 40    """
 41    Cross Validator Class for different types of data_sets
 42
 43    E.g. List -> [[Data], label]
 44         List[Dict] -> {"data": Data, "label": label}
 45         Huggingface Dataset Object -> Data(set="train", label = "label").select(idx)
 46    """
 47
 48    def __init__(self, dataset: Union[List, List[Dict], datasets.Dataset],
 49                 x_idx_label_or_attr: Union[str, int], y_idx_label_or_attr: Union[str, int],
 50                 cross_validator_type: [str, CrossValidatorTypes] = CrossValidatorTypes.Stratified,
 51                 seed=42, n_splits=5):
 52        # self.evaluator = evaluator
 53        self.cross_vali_fn = str_to_fn[cross_validator_type](n_splits=n_splits,
 54                                                             shuffle=True,
 55                                                             random_state=seed)
 56        self.dataset = dataset
 57        self.splits = []
 58
 59        self.x_label = x_idx_label_or_attr
 60        self.y_label = y_idx_label_or_attr
 61
 62        if self.dataset_type is None:
 63            self._determine_dataset_type()
 64            x, y = self.split_fn(x_idx_label_or_attr, y_idx_label_or_attr)
 65            self.splits = self.cross_vali_fn.split(x, y)
 66
 67    def _determine_dataset_type(self):
 68        if isinstance(self.dataset, list):
 69            if isinstance(self.dataset[0], dict):
 70                self.dataset_type = DatasetTypes.ListDict
 71                self.split_fn = self._split_dict
 72            elif isinstance(self.dataset[0], InputExample):
 73                self.dataset_type = DatasetTypes.ListInputExample
 74                self.split_fn = self._split_list
 75            else:
 76                self.dataset_type = DatasetTypes.List
 77                self.split_fn = self._split_list
 78        elif isinstance(self.dataset, datasets.Dataset):
 79            self.dataset_type = DatasetTypes.HuggingfaceDataset
 80            self.split_fn = self._split_dataset
 81        else:
 82            raise NotImplementedError("Unknown Dataset format")
 83
 84    def _split_list(self, *args, **kwargs):
 85        X = np.zeros(len(list(map(lambda k: k[self.x_label], self.dataset))))
 86        Y = map(lambda k: k[self.y_label], self.dataset)
 87
 88        return X, Y
 89
 90    def _split_dict(self, *args, **kwargs):
 91        X = np.zeros(len(list(map(lambda k: k[self.x_label], self.dataset))))
 92        Y = map(lambda k: k[self.y_label], self.dataset)
 93
 94        return X, Y
 95
 96    def _split_dataset(self, *args, **kwargs):
 97        # Rows data doesn't matter
 98        X = np.zeros(self.dataset.num_rows)
 99        Y = self.dataset[self.y_label]
100
101        return X, Y
102
103    def get_fold(self, fold_num: int):
104        """
105
106        :param fold_num: Which fold to pick
107        :return:
108        """
109
110        split = self.splits[fold_num]
111
112        return {
113            "train_idxs": split[0],
114            "val_idxs": split[1]
115        }
def split_k_fold(n_fold, data_files):
    """
    Split a CSV dataset into ``n_fold`` train/validation folds using the
    Hugging Face ``datasets`` percent-slice split syntax.

    Fold boundaries are computed as ``i * 100 // n_fold`` so that exactly
    ``n_fold`` folds are produced and together they cover 0%..100%, even when
    ``n_fold`` does not divide 100. When it does divide 100, the split strings
    are identical to a fixed ``100 // n_fold`` step.

    :param n_fold: Number of folds; must be in ``[1, 100]`` because slicing is
        done at whole-percent granularity.
    :param data_files: CSV file(s) forwarded to ``datasets.load_dataset``.
    :return: Tuple ``(trains_ds, vals_ds)``, each a list with one entry per
        fold. For fold ``i``, ``vals_ds[i]`` is the held-out slice and
        ``trains_ds[i]`` is the rest of the data.
    :raises ValueError: If ``n_fold`` is outside ``[1, 100]``.
    """
    if not 1 <= n_fold <= 100:
        raise ValueError(f"n_fold must be between 1 and 100, got {n_fold}")

    # n_fold + 1 boundaries -> n_fold contiguous (lo, hi) percentage ranges,
    # the last one always ending at 100%.
    bounds = [i * 100 // n_fold for i in range(n_fold + 1)]
    folds = list(zip(bounds, bounds[1:]))

    vals_ds = datasets.load_dataset('csv', split=[
        f'train[{lo}%:{hi}%]' for lo, hi in folds
    ], data_files=data_files)

    # Training split for each fold = everything before and after its slice.
    trains_ds = datasets.load_dataset('csv', split=[
        f'train[:{lo}%]+train[{hi}%:]' for lo, hi in folds
    ], data_files=data_files)

    return trains_ds, vals_ds
class CrossValidatorTypes(Enum):
    """
    Strategies available for partitioning a dataset into cross-validation folds.

    Each member's value is the lookup key of the splitter class in ``str_to_fn``.
    """

    # Maps to scikit-learn's StratifiedKFold.
    Stratified = "StratifiedKFold"
    # Maps to scikit-learn's KFold.
    KFold = "KFold"

Cross-validation strategies for splitting the dataset into folds.

Stratified = <CrossValidatorTypes.Stratified: 'StratifiedKFold'>
KFold = <CrossValidatorTypes.KFold: 'KFold'>
Inherited Members
enum.Enum
name
value
class CrossValidator:
    """
    Generate cross-validation fold indices for several dataset formats.

    E.g. List -> [[Data], label]
         List[Dict] -> {"data": Data, "label": label}
         Huggingface Dataset Object -> Data(set="train", label = "label").select(idx)

    Fold indices are computed once in ``__init__`` and retrieved per fold with
    ``get_fold``.
    """

    def __init__(self, dataset: Union[List, List[Dict], datasets.Dataset],
                 x_idx_label_or_attr: Union[str, int], y_idx_label_or_attr: Union[str, int],
                 cross_validator_type: Union[str, CrossValidatorTypes] = CrossValidatorTypes.Stratified,
                 seed=42, n_splits=5):
        """
        :param dataset: A list of rows or a Hugging Face ``datasets.Dataset``.
            The rows may be indexable sequences, dicts or InputExample objects.
        :param x_idx_label_or_attr: Index/key of the input field. For list
            datasets every row is indexed with it, so a missing key/index
            raises. Fold generation itself only uses the row count.
        :param y_idx_label_or_attr: Index/key of the label in each row, or the
            label column name for a ``datasets.Dataset``.
        :param cross_validator_type: A ``CrossValidatorTypes`` member or its
            string value ("StratifiedKFold" or "KFold").
        :param seed: ``random_state`` given to the splitter. Shuffling is always on.
        :param n_splits: Number of folds.
        :raises KeyError: If ``cross_validator_type`` is not a known strategy.
        :raises ValueError: If ``dataset`` is an empty list.
        :raises NotImplementedError: If the dataset format is unsupported.
        """
        # str_to_fn is keyed by the enum's string values. Normalise enum
        # members (including the default) to their value before the lookup.
        if isinstance(cross_validator_type, CrossValidatorTypes):
            cross_validator_type = cross_validator_type.value

        self.cross_vali_fn = str_to_fn[cross_validator_type](n_splits=n_splits,
                                                             shuffle=True,
                                                             random_state=seed)
        self.dataset = dataset

        self.x_label = x_idx_label_or_attr
        self.y_label = y_idx_label_or_attr

        # Both are assigned by _determine_dataset_type(). They must exist
        # before use, which the previous `if self.dataset_type is None`
        # check did not guarantee.
        self.dataset_type = None
        self.split_fn = None
        self._determine_dataset_type()

        x, y = self.split_fn(x_idx_label_or_attr, y_idx_label_or_attr)
        # split() returns a one-shot generator. Store a list so get_fold()
        # can index any fold, any number of times.
        self.splits = list(self.cross_vali_fn.split(x, list(y)))

    def _determine_dataset_type(self):
        """
        Set ``self.dataset_type`` and the matching ``self.split_fn``.

        For lists, the row type is inferred from the first element.

        :raises ValueError: If ``self.dataset`` is an empty list.
        :raises NotImplementedError: If the dataset is neither a list nor a
            ``datasets.Dataset``.
        """
        if isinstance(self.dataset, list):
            if not self.dataset:
                raise ValueError("Cannot cross-validate an empty dataset")
            first_row = self.dataset[0]
            if isinstance(first_row, dict):
                self.dataset_type = DatasetTypes.ListDict
                self.split_fn = self._split_dict
            elif isinstance(first_row, InputExample):
                self.dataset_type = DatasetTypes.ListInputExample
                self.split_fn = self._split_list
            else:
                self.dataset_type = DatasetTypes.List
                self.split_fn = self._split_list
        elif isinstance(self.dataset, datasets.Dataset):
            self.dataset_type = DatasetTypes.HuggingfaceDataset
            self.split_fn = self._split_dataset
        else:
            raise NotImplementedError("Unknown Dataset format")

    def _split_rows(self):
        """
        Build ``(X, Y)`` for a list whose rows support ``row[label]``.

        X is a zero placeholder with one entry per row, because only the row
        count matters for fold generation. Each row is still indexed with
        ``x_label`` so a missing field raises here. Y is the list of labels.
        """
        X = np.zeros(len([row[self.x_label] for row in self.dataset]))
        Y = [row[self.y_label] for row in self.dataset]
        return X, Y

    def _split_list(self, *args, **kwargs):
        """Return ``(X, Y)`` for a list of indexable rows."""
        # NOTE(review): InputExample rows are routed here too, which assumes
        # they support ``row[label]`` indexing. Confirm against InputExample.
        return self._split_rows()

    def _split_dict(self, *args, **kwargs):
        """Return ``(X, Y)`` for a list of dict rows."""
        return self._split_rows()

    def _split_dataset(self, *args, **kwargs):
        """Return ``(X, Y)`` for a Hugging Face ``datasets.Dataset``."""
        # Row contents don't matter for fold generation, only the row count.
        X = np.zeros(self.dataset.num_rows)
        Y = self.dataset[self.y_label]
        return X, Y

    def get_fold(self, fold_num: int):
        """
        Return the train/validation indices of one fold.

        :param fold_num: Which fold to pick. It is a list index, so negative
            values count from the last fold.
        :return: dict with keys "train_idxs" and "val_idxs", holding the
            indices produced by the splitter for that fold.
        :raises IndexError: If ``fold_num`` is out of range.
        """
        # Tolerate splits that were assigned as a lazy iterator.
        if not isinstance(self.splits, list):
            self.splits = list(self.splits)

        split = self.splits[fold_num]

        return {
            "train_idxs": split[0],
            "val_idxs": split[1]
        }

Cross-validator class that generates fold indices for several dataset formats.

Supported formats, e.g.: List -> [[Data], label]; List[Dict] -> {"data": Data, "label": label}; Huggingface Dataset object -> Data(set="train", label = "label").select(idx)

CrossValidator( dataset: Union[List, List[Dict], datasets.arrow_dataset.Dataset], x_idx_label_or_attr: Union[str, int], y_idx_label_or_attr: Union[str, int], cross_validator_type: [<class 'str'>, <enum 'CrossValidatorTypes'>] = <CrossValidatorTypes.Stratified: 'StratifiedKFold'>, seed=42, n_splits=5)
def __init__(self, dataset: Union[List, List[Dict], datasets.Dataset],
             x_idx_label_or_attr: Union[str, int], y_idx_label_or_attr: Union[str, int],
             cross_validator_type: Union[str, CrossValidatorTypes] = CrossValidatorTypes.Stratified,
             seed=42, n_splits=5):
    """
    Build the splitter and precompute all fold indices for ``dataset``.

    :param dataset: A list of rows or a Hugging Face ``datasets.Dataset``.
    :param x_idx_label_or_attr: Index/key of the input field, forwarded to
        the dataset-specific split function.
    :param y_idx_label_or_attr: Index/key (or column name) of the label,
        forwarded to the dataset-specific split function.
    :param cross_validator_type: A ``CrossValidatorTypes`` member or its
        string value ("StratifiedKFold" or "KFold").
    :param seed: ``random_state`` given to the splitter. Shuffling is always on.
    :param n_splits: Number of folds.
    :raises KeyError: If ``cross_validator_type`` is not a known strategy.
    """
    # str_to_fn is keyed by the enum's string values. Normalise enum members
    # (including the default) to their value before the lookup.
    if isinstance(cross_validator_type, CrossValidatorTypes):
        cross_validator_type = cross_validator_type.value

    self.cross_vali_fn = str_to_fn[cross_validator_type](n_splits=n_splits,
                                                         shuffle=True,
                                                         random_state=seed)
    self.dataset = dataset

    self.x_label = x_idx_label_or_attr
    self.y_label = y_idx_label_or_attr

    # _determine_dataset_type() assigns both. They must exist before use,
    # which the previous `if self.dataset_type is None` check did not guarantee.
    self.dataset_type = None
    self.split_fn = None
    self._determine_dataset_type()

    x, y = self.split_fn(x_idx_label_or_attr, y_idx_label_or_attr)
    # The label sequence may be a lazy iterator; give the splitter a list.
    # split() itself returns a one-shot generator, so store the folds as a
    # list that get_fold() can index repeatedly.
    self.splits = list(self.cross_vali_fn.split(x, list(y)))
def get_fold(self, fold_num: int):
    """
    Return the train/validation indices of one fold.

    :param fold_num: Which fold to pick. It is a list index, so negative
        values count from the last fold.
    :return: dict with keys "train_idxs" and "val_idxs", holding the indices
        produced by the splitter for that fold.
    :raises IndexError: If ``fold_num`` is out of range.
    """
    # A splitter's split() yields folds lazily, and a generator cannot be
    # subscripted. Materialise once and keep the list for later calls.
    if not isinstance(self.splits, list):
        self.splits = list(self.splits)

    split = self.splits[fold_num]

    return {
        "train_idxs": split[0],
        "val_idxs": split[1]
    }
Parameters
  • fold_num: Which fold to pick
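Returns
  • dict with keys "train_idxs" and "val_idxs": the training and validation indices of the selected fold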