J'ai trouvé cela dans la documentation de fastai. Ajouter des choses liées à la question que vous pouvez vérifier sur ici Lié à augs
class AlbumentationsTransform(RandTransform):
"Un gestionnaire de transformation pour plusieurs transformations `Albumentation`"
split_idx,order=None,2
def __init__(self, train_aug, valid_aug): store_attr()
def before_call(self, b, split_idx):
self.idx = split_idx
def encodes(self, img: PILImage):
if self.idx == 0:
aug_img = self.train_aug(image=np.array(img))['image']
else:
aug_img = self.valid_aug(image=np.array(img))['image']
return PILImage.create(aug_img)
Essentiellement, c'est tout ce dont vous avez besoin
Pour définir des augs différents, utilisez
def get_train_aug(): return albumentations.Compose([
albumentations.RandomResizedCrop(224,224),
albumentations.Transpose(p=0.5),
])
def get_valid_aug(): return albumentations.Compose([
albumentations.CenterCrop(224,224, p=1.),
albumentations.Resize(224,224)
], p=1.)
Et ensuite
item_tfms = [Resize(256), AlbumentationsTransform(get_train_aug(), get_valid_aug())]
dls = ImageDataLoaders.from_name_func(
path, get_image_files(path), valid_pct=0.2, seed=42,
label_func=is_cat, item_tfms=item_tfms)
dls.train.show_batch(max_n=4)
Profitez-en