split
Dataset.split(input_seed: int = 0, split_names: Sequence[str] = ('train', 'valid'), target_split_shares: Sequence[float] = (0.8, 0.2), keep_separate_groups: str | ContinuousGroup | Sequence[str | ContinuousGroup] = ('image_id',), keep_balanced_groups: str | ContinuousGroup | Sequence[str | ContinuousGroup] = ('category_id',), keep_balanced_groups_weights: Sequence[float] | None = None, inplace: bool = False, hist_cost_weight: float = 1, share_cost_weight: float = 1, earth_mover_regularization: float = 0) -> Self
Perform the split operation on annotations and images.
This algorithm works in two steps:
- Divide the dataframe into atomic sub-frames. Given the image and annotation attributes that need to be kept separate, we construct sub-frames of elements that cannot be in different splits.
- Construct the split dataframes iteratively, trying to keep the repartition of the given column values balanced between splits while keeping split sizes as close to the target shares as possible. Each atomic sub-dataframe is routed to the split that minimizes a cost function combining these two objectives.
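Step 1 amounts to finding connected components: two rows are linked when they share a value in any keep-separate column (image_id included), and that link is transitive. The sketch below uses a union-find (disjoint-set) structure and assumes rows are plain dicts rather than the images and annotations DataFrames; the helper atomic_chunks is hypothetical and not part of lours, whose actual implementation may compute chunks differently.

```python
from collections import defaultdict
from collections.abc import Hashable, Mapping, Sequence


def atomic_chunks(
    rows: Sequence[Mapping[str, Hashable]], separate_columns: Sequence[str]
) -> list[list[int]]:
    """Hypothetical helper: group row indices so that rows sharing a value in any
    of ``separate_columns`` (directly or through other rows) share a chunk."""
    parent = list(range(len(rows)))

    def find(i: int) -> int:
        # Follow parent links to the root, halving the path on the way.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    def union(i: int, j: int) -> None:
        root_i, root_j = find(i), find(j)
        if root_i != root_j:
            parent[root_j] = root_i

    for column in separate_columns:
        # Link every row to the first row seen with the same value.
        first_row_for_value: dict[Hashable, int] = {}
        for index, row in enumerate(rows):
            value = row[column]
            if value in first_row_for_value:
                union(first_row_for_value[value], index)
            else:
                first_row_for_value[value] = index

    # Collect rows by root: each group is one atomic chunk.
    chunks: dict[int, list[int]] = defaultdict(list)
    for index in range(len(rows)):
        chunks[find(index)].append(index)
    return list(chunks.values())


images = [
    {"image_id": 0, "separate": "system"},
    {"image_id": 1, "separate": "law"},
    {"image_id": 2, "separate": "system"},
]
print(atomic_chunks(images, ["image_id", "separate"]))  # [[0, 2], [1]]
```

Because linking is transitive, two rows with different values in every column can still land in the same chunk if a third row connects them, which is why the number of chunks can be smaller than the number of distinct values.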
Warning
If self.images and self.annotations each have a column with the same name, the column in self.images will be ignored. Make sure column names are mutually exclusive to avoid problems.
See pandas.split_dataframe().
Parameters:
input_seed – Seed used to shuffle the atomic sub-dataframes before step 2 of the splitting algorithm. Defaults to 0.
split_names – Names of the splits. Must be the same length as target_split_shares. Defaults to ("train", "valid").
target_split_shares – Target relative size of each split. Must be the same length as split_names, and will be normalized so that it sums to 1. Defaults to (0.8, 0.2).
keep_separate_groups – Columns or groups (see group) in the annotations or images DataFrame to keep separate. That is, for a particular column or group, two rows with the same value cannot be in different splits. Note that image_id will be added to that list, because splitting happens at the image level. Defaults to ("image_id",).
keep_balanced_groups – Columns or groups (see group) in the annotations or images DataFrame to keep balanced. That is, for a particular group, the distribution of values in each split should be as close as possible to its distribution in the original DataFrame. Defaults to ("category_id",).
keep_balanced_groups_weights – Importance of each group to keep balanced when computing the histogram cost. If not None, must be a single float or the same size as keep_balanced_groups. Defaults to None.
inplace – If set, will modify dataframes in place. This has a lower memory footprint, but can silently modify other objects (like Datasets) that use them. Defaults to False.
hist_cost_weight – Importance of the histogram cost for balanced groups. The higher it is, the more the histogram cost weighs in the decision of which split each atomic sub-dataframe goes to. Defaults to 1.
share_cost_weight – Importance of the share cost, i.e. how far split sizes are from the target shares. The higher it is, the more the share cost weighs in the decision of which split each atomic sub-dataframe goes to. Defaults to 1.
earth_mover_regularization – Regularization parameter applied to Sinkhorn's algorithm during earth mover distance computation. See lours.dataset.split.balanced_group.earth_mover_distance(). Defaults to 0.
Returns:
A new Dataset with the split column populated with the corresponding split names.
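Step 2 and the parameters input_seed, target_split_shares, hist_cost_weight and share_cost_weight can be illustrated with a greedy router. This is a minimal sketch under several assumptions, not lours' implementation: each atomic chunk is reduced to a Counter histogram of a single balanced column (so keep_balanced_groups_weights is not modeled); the histogram cost swaps the earth mover distance for the total variation distance, which equals the earth mover distance between normalized histograms only under a 0/1 ground cost between categories, and no Sinkhorn regularization is applied; the share cost is assumed to be the L1 gap between current and normalized target shares. The names greedy_split and total_variation are hypothetical.

```python
import random
from collections import Counter
from collections.abc import Sequence


def total_variation(hist: Counter, reference: Counter) -> float:
    """Half the L1 distance between two normalized histograms (0 means identical)."""
    hist_total = sum(hist.values())
    if hist_total == 0:
        return 0.0  # an empty split has no distribution to compare yet
    ref_total = sum(reference.values())
    keys = set(hist) | set(reference)
    return 0.5 * sum(abs(hist[k] / hist_total - reference[k] / ref_total) for k in keys)


def greedy_split(
    chunks: Sequence[Counter],
    split_names: Sequence[str] = ("train", "valid"),
    target_split_shares: Sequence[float] = (0.8, 0.2),
    hist_cost_weight: float = 1.0,
    share_cost_weight: float = 1.0,
    input_seed: int = 0,
) -> list[str]:
    """Hypothetical helper: route each chunk (a histogram of one balanced column)
    to the split with the lowest weighted histogram + share cost."""
    if len(split_names) != len(target_split_shares):
        raise ValueError("split_names and target_split_shares must have the same length")
    total_share = sum(target_split_shares)
    targets = [share / total_share for share in target_split_shares]  # normalized to sum to 1
    reference = sum(chunks, Counter())  # value distribution of the whole dataset
    split_hists = [Counter() for _ in split_names]
    split_sizes = [0] * len(split_names)
    assignment = [""] * len(chunks)

    order = list(range(len(chunks)))
    random.Random(input_seed).shuffle(order)  # seeded shuffle before routing
    for chunk_index in order:
        chunk = chunks[chunk_index]
        size = sum(chunk.values())
        costs = []
        for candidate in range(len(split_names)):
            # State of all splits if the chunk went to `candidate`.
            hists = [h + chunk if i == candidate else h for i, h in enumerate(split_hists)]
            sizes = [s + size if i == candidate else s for i, s in enumerate(split_sizes)]
            assigned = max(sum(sizes), 1)  # avoid dividing by zero for empty chunks
            share_cost = sum(abs(s / assigned - t) for s, t in zip(sizes, targets))
            hist_cost = sum(total_variation(h, reference) for h in hists)
            costs.append(hist_cost_weight * hist_cost + share_cost_weight * share_cost)
        best = min(range(len(costs)), key=costs.__getitem__)  # first split wins ties
        split_hists[best] += chunk
        split_sizes[best] += size
        assignment[chunk_index] = split_names[best]
    return assignment


# Ten identical single-row chunks: the histogram cost is always 0, so only shares matter.
print(Counter(greedy_split([Counter({"a": 1})] * 10)))  # Counter({'train': 8, 'valid': 2})
```

Recomputing the cost of every split for each candidate keeps the sketch short at the price of extra work per chunk; updating only the two splits that actually change would be cheaper.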
See also
More in-depth explanation in this tutorial
Example
>>> from lours.utils.doc_utils import dummy_dataset
>>> example = dummy_dataset(
...     200,
...     n_attribute_columns_images={"balanced": 10, "separate": 10},
...     split_names=None,
...     seed=1,
... )
>>> example
Dataset object containing 200 images and 2 objects
Name : shake_effort_many
Images root : care/suggest
Images :
     width  height        relative_path   type  balanced  separate
id
0      955     488  determine/story.jpg   .jpg      send    system
1      131     895       air/method.bmp   .bmp      note    system
2      229     880   political/lead.jpg   .jpg  anything       law
3      840     384        like/safe.bmp   .bmp  anything    likely
4      953     668      suffer/set.jpeg  .jpeg  training    attack
..     ...     ...                  ...    ...       ...       ...
195    122     437    state/almost.tiff  .tiff  anything      star
196    752     300     weight/tend.jpeg  .jpeg     could      rest
197    554     228  remember/summer.png   .png  anything    system
198    688     605       yet/though.png   .png      note    number
199    243     227   describe/road.tiff  .tiff       end    number

[200 rows x 6 columns]
Annotations :
    image_id  category_str  category_id  ...   box_y_min   box_width  box_height
id                                       ...
0         77         reach           22  ...   45.427512   40.116677  318.073851
1        137      marriage           15  ...  202.481384  435.389400  475.375279

[2 rows x 7 columns]
Label map : {14: 'listen', 15: 'marriage', 22: 'reach'}
>>> example.images["separate"].value_counts()
separate
star      27
likely    27
number    27
attack    22
rest      20
law       18
entire    17
enough    16
system    15
often     11
Name: count, dtype: int64
>>> splitted = example.split(
...     keep_balanced_groups=["balanced"], keep_separate_groups=["separate"]
... )
Splitting annotations ...
Separating input data into atomic chunks
1 chunks to distribute across 2 splits
Splitting images ...
Separating input data into atomic chunks
9 chunks to distribute across 2 splits
>>> splitted
Dataset object containing 200 images and 2 objects
Name : shake_effort_many
Images root : care/suggest
Images :
     width  height        relative_path   type  split  balanced  separate
id
0      955     488  determine/story.jpg   .jpg  train      send    system
1      131     895       air/method.bmp   .bmp  train      note    system
2      229     880   political/lead.jpg   .jpg  valid  anything       law
3      840     384        like/safe.bmp   .bmp  train  anything    likely
4      953     668      suffer/set.jpeg  .jpeg  train  training    attack
..     ...     ...                  ...    ...    ...       ...       ...
195    122     437    state/almost.tiff  .tiff  train  anything      star
196    752     300     weight/tend.jpeg  .jpeg  valid     could      rest
197    554     228  remember/summer.png   .png  train  anything    system
198    688     605       yet/though.png   .png  train      note    number
199    243     227   describe/road.tiff  .tiff  train       end    number

[200 rows x 7 columns]
Annotations :
    image_id  category_str  category_id  ...   box_y_min   box_width  box_height
id                                       ...
0         77         reach           22  ...   45.427512   40.116677  318.073851
1        137      marriage           15  ...  202.481384  435.389400  475.375279

[2 rows x 8 columns]
Label map : {14: 'listen', 15: 'marriage', 22: 'reach'}
>>> splitted.images.groupby("split")["separate"].value_counts()
split  separate
train  star        27
       likely      27
       number      27
       attack      22
       entire      17
       enough      16
       system      15
       rest         0
       law          0
       often        0
valid  rest        20
       law         18
       often       11
       entire       0
       star         0
       attack       0
       likely       0
       system       0
       enough       0
       number       0
Name: count, dtype: int64
>>> splitted.images.groupby("split")["balanced"].value_counts()
split  balanced
train  could       21
       coach       20
       end         20
       firm        17
       send        16
       anything    14
       training    13
       lead        10
       note        10
       region      10
valid  could        8
       send         8
       note         6
       firm         5
       anything     5
       training     4
       region       4
       end          4
       coach        3
       lead         2
Name: count, dtype: int64
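In the output above, no value of the separate column appears with a nonzero count in both splits. A small check of that guarantee is sketched below, assuming the rows are plain dicts, for instance obtained with splitted.images.to_dict("records"); the helper name leaking_values is hypothetical. Working on the DataFrame directly, splitted.images.groupby("separate")["split"].nunique() > 1 flags the same values.

```python
from collections import defaultdict
from collections.abc import Hashable, Iterable, Mapping


def leaking_values(
    rows: Iterable[Mapping[str, Hashable]],
    group_column: str,
    split_column: str = "split",
) -> dict[Hashable, set[Hashable]]:
    """Hypothetical helper: map each value of ``group_column`` that appears in
    more than one split to the set of splits it appears in."""
    splits_per_value: dict[Hashable, set[Hashable]] = defaultdict(set)
    for row in rows:
        splits_per_value[row[group_column]].add(row[split_column])
    # Keep only values that ended up in two or more splits.
    return {value: splits for value, splits in splits_per_value.items() if len(splits) > 1}


rows = [
    {"separate": "star", "split": "train"},
    {"separate": "rest", "split": "valid"},
    {"separate": "star", "split": "train"},
]
print(leaking_values(rows, "separate"))  # {}
```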