Skip to content
Snippets Groups Projects
Unverified Commit 4f26fee9 authored by Patrick Labatut's avatar Patrick Labatut Committed by GitHub
Browse files

Datasets interface cleanup (#59)

Remove unused sample decoding interface in datasets.
parent f8969297
No related branches found
No related tags found
No related merge requests found
...@@ -20,9 +20,6 @@ class DatasetWithEnumeratedTargets(Dataset): ...@@ -20,9 +20,6 @@ class DatasetWithEnumeratedTargets(Dataset):
target = self._dataset.get_target(index) target = self._dataset.get_target(index)
return (index, target) return (index, target)
def get_sample_decoder(self, index: int) -> Any:
return self._dataset.get_sample_decoder(index)
def __getitem__(self, index: int) -> Tuple[Any, Tuple[Any, int]]: def __getitem__(self, index: int) -> Tuple[Any, Tuple[Any, int]]:
image, target = self._dataset[index] image, target = self._dataset[index]
target = index if target is None else target target = index if target is None else target
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from io import BytesIO from io import BytesIO
from typing import Any, Tuple from typing import Any
from PIL import Image from PIL import Image
...@@ -30,11 +30,3 @@ class TargetDecoder(Decoder): ...@@ -30,11 +30,3 @@ class TargetDecoder(Decoder):
def decode(self) -> Any: def decode(self) -> Any:
return self._target return self._target
class TupleDecoder(Decoder):
def __init__(self, *decoders: Decoder):
self._decoders: Tuple[Decoder, ...] = decoders
def decode(self) -> Any:
return (decoder.decode() for decoder in self._decoders)
...@@ -8,7 +8,7 @@ from typing import Any, Tuple ...@@ -8,7 +8,7 @@ from typing import Any, Tuple
from torchvision.datasets import VisionDataset from torchvision.datasets import VisionDataset
from .decoders import Decoder, TargetDecoder, ImageDataDecoder, TupleDecoder from .decoders import TargetDecoder, ImageDataDecoder
class ExtendedVisionDataset(VisionDataset): class ExtendedVisionDataset(VisionDataset):
...@@ -35,13 +35,5 @@ class ExtendedVisionDataset(VisionDataset): ...@@ -35,13 +35,5 @@ class ExtendedVisionDataset(VisionDataset):
return image, target return image, target
def get_sample_decoder(self, index: int) -> Decoder:
image_data = self.get_image_data(index)
target = self.get_target(index)
return TupleDecoder(
ImageDataDecoder(image_data),
TargetDecoder(target),
)
def __len__(self) -> int: def __len__(self) -> int:
raise NotImplementedError raise NotImplementedError
...@@ -22,7 +22,6 @@ from .extended import ExtendedVisionDataset ...@@ -22,7 +22,6 @@ from .extended import ExtendedVisionDataset
_Labels = int _Labels = int
_DEFAULT_MMAP_CACHE_SIZE = 16 # Warning: This can exhaust file descriptors _DEFAULT_MMAP_CACHE_SIZE = 16 # Warning: This can exhaust file descriptors
_IMAGES_SUBDIR_IMAGENET_21KP = "062717"
@dataclass @dataclass
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment