Click on the following items to reveal relevant code. These code blocks suffice for loading data and evaluating models.
- Salient ImageNet Dataloader
import os
import numpy as np
from PIL import Image
import pandas as pd
from torchvision import transforms as transforms
from torch.utils.data import Dataset
class SalientImageNet(Dataset):
    '''
    Dataset of images of one class paired with a merged soft mask.

    For the class selected by `class_index`, every image listed under one of
    the requested `feature_indices` in `image_names_map.csv` becomes one item.
    Each item is `(image_tensor, mask_tensor)`, where the mask is the
    element-wise maximum of the per-feature masks stored for that image.
    '''

    def __init__(self, images_path, masks_path, class_index, feature_indices,
                 resize_size=256, crop_size=224):
        '''
        images_path: root folder; images are read from `<images_path>/train/<wordnet_id>`.
        masks_path: root folder holding `wordnet_dict.py` and one sub-folder per wordnet id.
        class_index: key looked up in the mapping loaded from `wordnet_dict.py`.
        feature_indices: iterable of feature indices; `str(index)` must be a CSV column.
        resize_size, crop_size: arguments of the Resize / CenterCrop transforms.
        '''
        self.transform = transforms.Compose([
            transforms.Resize(resize_size),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
        ])

        # SECURITY NOTE(review): eval() executes whatever the file contains.
        # Only point `masks_path` at a trusted copy of the dataset; if the file
        # is a plain literal, ast.literal_eval would be the safer choice.
        with open(os.path.join(masks_path, 'wordnet_dict.py')) as dict_file:
            wordnet_dict = eval(dict_file.read())
        wordnet_id = wordnet_dict[class_index]

        self.images_path = os.path.join(images_path, 'train', wordnet_id)
        self.masks_path = os.path.join(masks_path, wordnet_id)

        image_names_file = os.path.join(self.masks_path, 'image_names_map.csv')
        image_names_df = pd.read_csv(image_names_file)

        self.image_names, self.feature_indices_dict = self._group_features_by_image(
            image_names_df, feature_indices)

    @staticmethod
    def _group_features_by_image(image_names_df, feature_indices):
        '''Return (sorted unique image names, {image name: [feature indices listing it]}).'''
        # Imported locally because the file's import header does not provide it.
        from collections import defaultdict

        image_names = []
        feature_indices_dict = defaultdict(list)
        for feature_index in feature_indices:
            # Columns of the CSV are addressed by the stringified feature index.
            for image_name in image_names_df[str(feature_index)].to_numpy():
                image_names.append(image_name)
                feature_indices_dict[image_name].append(feature_index)
        # np.unique both de-duplicates and sorts the names.
        return np.unique(np.array(image_names)), feature_indices_dict

    def __len__(self):
        '''Number of distinct images across the requested features.'''
        return len(self.image_names)

    def __getitem__(self, index):
        '''Return `(image_tensor, mask_tensor)` for the image at position `index`.'''
        image_name = self.image_names[index]
        curr_image_path = os.path.join(self.images_path, image_name + '.JPEG')

        # Context manager releases the file handle as soon as the pixels are converted.
        with Image.open(curr_image_path) as raw_image:
            image = raw_image.convert("RGB")
        image_tensor = self.transform(image)

        feature_indices = self.feature_indices_dict[image_name]

        # Start from an empty mask with the spatial size of the transformed image;
        # np.maximum below requires each stored mask to broadcast against it.
        all_mask = np.zeros(image_tensor.shape[1:])
        for feature_index in feature_indices:
            curr_mask_path = os.path.join(self.masks_path, 'feature_' + str(feature_index), image_name + '.JPEG')
            with Image.open(curr_mask_path) as mask_image:
                mask = np.asarray(mask_image) / 255.
            # Union of soft masks: keep the strongest activation per pixel.
            all_mask = np.maximum(all_mask, mask)

        all_mask = np.uint8(all_mask * 255)
        all_mask = Image.fromarray(all_mask)
        # NOTE(review): the merged mask goes through the same Resize/CenterCrop
        # pipeline as the raw image even though it already has the cropped
        # size — confirm this re-scaling is intended.
        mask_tensor = self.transform(all_mask)
        return image_tensor, mask_tensor
- Computing core and spurious accuracy
def core_spur_accuracy(dataset, model, core=True, noise_sigma=0.25, num_trials=5, apply_norm=True):
    '''
    Estimate core and spurious accuracy of `model` under region-restricted Gaussian noise.

    Core regions are taken to be dilated core masks, and spurious regions are
    their complement (1 - dilated core mask).
    Core accuracy counts correct predictions when noise is added to the
    spurious region only; spurious accuracy counts correct predictions when
    noise is added to the core region only.

    dataset: yields (image, mask, label) triples, e.g. the Salient ImageNet test
        set or any dataset with soft segmentation masks for core regions.
    model: classifier called on image batches; its argmax over dim 1 is the prediction.
    core: currently unused; kept for interface compatibility.
    noise_sigma: standard deviation of the Gaussian noise.
    num_trials: number of independent noise draws per batch.
    apply_norm: if True, apply ImageNet mean/std normalization after adding noise.

    Returns (core_acc, spur_acc, core_acc_by_class, spur_acc_by_class).
    The two scalars are percentages obtained by averaging the per-class
    accuracies; the dictionaries map class index to a fraction in [0, 1] and
    only contain classes that appeared in the data.
    '''
    loader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True, num_workers=16, pin_memory=True)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    cnt_by_class = {i: 0 for i in range(1000)}
    core_cc_by_class = {i: 0 for i in range(1000)}
    spur_cc_by_class = {i: 0 for i in range(1000)}

    for dat in tqdm(loader):
        imgs, masks, labels = [x.cuda() for x in dat]
        # Samples whose mask is entirely zero have no core region; drop them.
        idx_with_masks = (masks.flatten(1).sum(1) != 0)
        imgs, masks, labels = [x[idx_with_masks] for x in [imgs, masks, labels]]
        if imgs.shape[0] == 0:
            # Nothing left in this batch; skip the forward passes.
            continue
        # NOTE(review): dilate_erode_fast is not defined in the visible part of
        # this file; it is presumably a faster equivalent of dilate_erode — confirm.
        masks = dilate_erode_fast(masks)
        # Labels do not change between trials, so compute the class list once.
        classes_in_batch = np.unique(labels.cpu().numpy())

        for trial in range(num_trials):
            noise = torch.randn_like(imgs, device=imgs.device) * noise_sigma
            # noisy_core: noise inside the core mask; noisy_spur: noise outside it.
            noisy_core, noisy_spur = [torch.clamp(imgs + (x * noise), 0, 1) for x in [masks, 1-masks]]
            if apply_norm:
                noisy_core, noisy_spur = [normalize(x) for x in [noisy_core, noisy_spur]]
            # Pure inference: no autograd graph is needed for argmax predictions.
            with torch.no_grad():
                noisy_core_preds, noisy_spur_preds = [model(x).argmax(1) for x in [noisy_core, noisy_spur]]
            for y in classes_in_batch:
                is_y = (labels == y)
                # Correct under spurious-region noise => credited to core accuracy.
                core_cc_by_class[y] += (noisy_spur_preds[is_y] == y).sum().item()
                # Correct under core-region noise => credited to spurious accuracy.
                spur_cc_by_class[y] += (noisy_core_preds[is_y] == y).sum().item()
                # Counted once per trial so the ratio below is a per-prediction accuracy.
                cnt_by_class[y] += is_y.sum().item()

    core_acc_by_class, spur_acc_by_class = dict(), dict()
    for c in cnt_by_class:
        if cnt_by_class[c] == 0:
            continue
        core_acc_by_class[c] = core_cc_by_class[c] / cnt_by_class[c]
        spur_acc_by_class[c] = spur_cc_by_class[c] / cnt_by_class[c]
    # Class-balanced averages, reported as percentages.
    core_acc, spur_acc = [100.*np.average(list(x.values())) for x in [core_acc_by_class, spur_acc_by_class]]
    return core_acc, spur_acc, core_acc_by_class, spur_acc_by_class
- Computing relative core sensitivity
def rel_score(core_acc, spur_acc):
    '''
    Relative core sensitivity for a scalar pair (core_acc, spur_acc).

    The gap core_acc - spur_acc is divided by twice the smaller of the mean
    accuracy and one minus the mean accuracy. When the mean is exactly 0 or 1
    that divisor would be zero, so 0 is returned instead. The formula treats 1
    as the maximum accuracy, so inputs are presumably fractions in [0, 1].
    '''
    mean_acc = 0.5 * (core_acc + spur_acc)
    # Degenerate means leave no room for a gap; avoid dividing by zero.
    if mean_acc == 0 or mean_acc == 1:
        return 0
    headroom = min(mean_acc, 1 - mean_acc)
    return (core_acc - spur_acc) / (2 * headroom)
- Dilation
def dilate_erode(masks, dilate=True, iterations=15, kernel=5):
    '''
    Dilate or erode a tensor of soft segmentation masks.

    masks: tensor of shape (batch, channels, height, width); only channel 0 is used.
    dilate: True takes the neighbourhood maximum (dilation), False the minimum (erosion).
    iterations: number of times the operation is applied.
    kernel: odd extent of the plus-shaped neighbourhood (the pixel itself plus
        kernel // 2 pixels in each of the four axis directions).

    Returns a tensor of shape (batch, 3, height, width) holding three copies of
    the processed mask, rescaled so its maximum is 1. If the processed mask is
    empty or its maximum is 0, the rescaling is skipped instead of producing NaNs.
    '''
    assert kernel % 2 == 1
    half_k = kernel // 2
    batch_size, _, height, width = masks.shape
    out = masks[:, 0, :, :].clone()

    # The border is filled with the neutral element of the reduction (0 for max,
    # 1 for min) so that pixels outside the image never influence the result.
    padded_shape = (batch_size, height + 2 * half_k, width + 2 * half_k)
    if dilate:
        padded = torch.zeros(padded_shape, device=masks.device)
    else:
        padded = torch.ones(padded_shape, device=masks.device)
    reduce = torch.max if dilate else torch.min

    rows = slice(half_k, half_k + height)
    cols = slice(half_k, half_k + width)
    for _ in range(iterations):
        # The unshifted mask must take part in the reduction, otherwise a pixel
        # would be replaced by its neighbours instead of merged with them.
        centered = padded.clone()
        centered[:, rows, cols] = out
        shifted_copies = [centered]
        for j in range(1, half_k + 1):
            # One copy shifted by j pixels in each of the four axis directions.
            for d_row, d_col in ((-j, 0), (j, 0), (0, j), (0, -j)):
                shifted = padded.clone()
                shifted[:, half_k + d_row:half_k + d_row + height,
                        half_k + d_col:half_k + d_col + width] = out
                shifted_copies.append(shifted)
        out = reduce(torch.stack(shifted_copies), dim=0)[0]
        # Crop the padding away again.
        out = out[:, rows, cols]

    out = torch.stack([out, out, out], dim=1)
    if out.numel() > 0:
        peak = torch.max(out)
        # Dividing by a zero maximum would turn the whole mask into NaNs.
        if peak != 0:
            out = out / peak
    return out