Resizing and center-cropping images in Python, plus a filename-insensitive data loader

I recently wrote this utility to center-crop and resize images in Python. Because every image gets cropped and scaled the same way, I don't have to worry about how my images are being fed into an AI pipeline.

Additionally, whereas I previously used bash to rename the files in a folder to sequential names such as `1.jpg`, the current implementation ignores filenames, which saves me time when massaging the data.


from PIL import Image
import torch; torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
from torch.utils.data import Dataset 
from torchvision import transforms
import numpy as np
import os
import imageio
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
class ImgDataset(Dataset):
    """Map-style dataset that center-crops and resizes every image in a folder.

    Every regular file directly inside ``config.data_dir`` is treated as an
    image. Each item is cropped around its center to the aspect ratio
    ``config.IMG_W / config.IMG_H`` and then resized to exactly
    ``(config.IMG_W, config.IMG_H)``, so all samples share one spatial size.

    Args:
        config: Object exposing ``data_dir`` (folder of images) and the
            target size ``IMG_W`` / ``IMG_H`` in pixels.
        transform: Optional callable applied to the cropped, resized RGB PIL
            image. When ``None``, ``transforms.ToTensor()`` is used.
    """

    def __init__(self, config=None, transform=None):
        self.transform = transform
        self.c = config

    def __len__(self):
        return len(self.img_names())

    def img_names(self):
        """Return the sorted names of the regular files in ``config.data_dir``.

        ``os.listdir`` returns entries in arbitrary order; sorting keeps the
        index -> file mapping deterministic across runs and machines.
        Subdirectories are skipped.
        """
        data_dir = self.c.data_dir
        return sorted(
            name
            for name in os.listdir(data_dir)
            if os.path.isfile(os.path.join(data_dir, name))
        )

    @staticmethod
    def _center_crop_box(w, h, ratio):
        """Return the (left, top, right, bottom) box of the largest centered
        region of a ``w`` x ``h`` image whose width/height is ``ratio``."""
        if w / h > ratio:
            # Too wide: trim equal amounts from the left and right.
            excess = round((w - h * ratio) / 2)
            return (excess, 0, w - excess, h)
        # Too tall (or already the right shape): trim the top and bottom.
        excess = round((h - w / ratio) / 2)
        return (0, excess, w, h - excess)

    @staticmethod
    def _label_from_name(name, idx):
        """Return the integer label for file ``name`` at position ``idx``.

        The part of the name before the first ``.`` is parsed as an integer
        (``"7.jpg"`` -> 7). Non-numeric names fall back to ``idx`` so the
        loader does not depend on how the files are named.
        """
        stem = name.split(".")[0]
        try:
            return int(stem)
        except ValueError:
            return idx

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(image, label)``.

        ``image`` is the center-cropped, resized RGB image after the
        transform; ``label`` is a 0-d ``torch.int32`` tensor.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # List the folder once so the image path and the label refer to the
        # same file even if the directory changes between calls.
        name = self.img_names()[idx]
        img_path = os.path.join(self.c.data_dir, name)
        # Context manager closes the underlying file handle promptly.
        with Image.open(img_path) as src:
            img = src.convert('RGB')
        ratio = self.c.IMG_W / self.c.IMG_H
        img = img.crop(self._center_crop_box(*img.size, ratio))
        # resize (unlike thumbnail) also enlarges small images and hits the
        # exact target size, so samples can be stacked into a batch.
        img = img.resize((self.c.IMG_W, self.c.IMG_H))
        trans = self.transform if self.transform is not None else transforms.ToTensor()
        image = trans(img)
        # torch.tensor cannot build a tensor from a str; parse the name first.
        label = torch.tensor(self._label_from_name(name, idx), dtype=torch.int32)
        return image, label
Related Articles

Please log in to post a comment: