1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
def unpickle(file):
    """Load and return the object stored in a pickle file.

    The file is loaded with ``encoding='bytes'``, so Python 2 ``str``
    objects inside the pickle come back as ``bytes`` (the callers in this
    script index the result with ``bytes`` keys such as ``b'data'``).

    Args:
        file: Path to the pickle file.

    Returns:
        The unpickled object.

    Note:
        ``pickle.load`` can execute arbitrary code while loading; only call
        this on files from a trusted source.
    """
    with open(file, 'rb') as fo:
        # Named ``contents`` rather than ``dict`` to avoid shadowing the builtin.
        contents = pickle.load(fo, encoding='bytes')
    return contents
def write_img(data, coarse_labels, fine_labels, filenames, isTrain=True,
              res_data_dir='/home/zj/data/decompress_cifar_100'):
    """Write images to ``<res_data_dir>/<split>/<coarse>/<fine>/<filename>``.

    Args:
        data: Sequence of flattened images; each row is reshaped to
            ``(3, 32, 32)`` (channel-first, RGB) before being written.
        coarse_labels: Coarse label of each image (first directory level).
        fine_labels: Fine label of each image (second directory level).
        filenames: File name of each image as UTF-8 encoded ``bytes``.
        isTrain: Write under ``train`` when True, otherwise under ``test``.
        res_data_dir: Root output directory; defaults to the previously
            hard-coded location.

    Raises:
        ValueError: If the four input sequences differ in length.
        OSError: If OpenCV fails to write an image.
    """
    n_images = len(coarse_labels)
    if not (len(data) == len(fine_labels) == len(filenames) == n_images):
        raise ValueError('data, coarse_labels, fine_labels and filenames '
                         'must have the same length')

    split_dir = os.path.join(res_data_dir, 'train' if isTrain else 'test')
    # makedirs creates missing parents too and tolerates existing dirs,
    # unlike the exists()+mkdir() check-then-act pattern.
    os.makedirs(split_dir, exist_ok=True)

    for img_data, coarse, fine, name in zip(data, coarse_labels,
                                            fine_labels, filenames):
        fine_cate_dir = os.path.join(split_dir, str(coarse), str(fine))
        os.makedirs(fine_cate_dir, exist_ok=True)
        img_path = os.path.join(fine_cate_dir, str(name, encoding='utf-8'))

        # Rows are channel-first RGB; OpenCV expects height x width x BGR.
        img = np.transpose(img_data.reshape(3, 32, 32), (1, 2, 0))
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        # imwrite signals failure via its return value, not an exception.
        if not cv2.imwrite(img_path, img):
            raise OSError('cv2.imwrite failed for %s' % img_path)
def decompress_img(data_dir='/home/zj/data/cifar-100-python/'):
    """Decode the ``test`` and ``train`` batch files into image folders.

    Each batch file under ``data_dir`` is unpickled and its images, labels
    and file names are handed to ``write_img``. The train/test split is
    chosen from the batch's ``b'batch_label'`` entry.

    Args:
        data_dir: Directory containing the ``test`` and ``train`` pickle
            files; defaults to the previously hard-coded location.
    """
    for item in ['test', 'train']:
        # Join against the unchanged root on every iteration; reassigning
        # data_dir here would make 'train' resolve to '.../test/train'.
        batch_path = os.path.join(data_dir, item)
        di = unpickle(batch_path)

        batch_label = str(di.get(b'batch_label'), encoding='utf-8')
        filenames = di.get(b'filenames')
        fine_labels = di.get(b'fine_labels')
        coarse_labels = di.get(b'coarse_labels')
        data = di.get(b'data')

        if 'train' in batch_label:
            write_img(data, coarse_labels, fine_labels, filenames)
        else:
            write_img(data, coarse_labels, fine_labels, filenames,
                      isTrain=False)
|