深度学习实战 cifar数据集预处理技术分析

栏目: 数据库 · 发布时间: 4年前

内容简介:本文首发于微信公众号:"算法与编程之美",欢迎关注,及时了解更多此系列文章。cifar数据集是以cifar-10-python.tar.gz的压缩包格式存储在远程服务器,利用keras的get_file()方法下载压缩包并执行解压,解压后得到:

欢迎点击「算法与编程之美」↑关注我们!

本文首发于微信公众号:"算法与编程之美",欢迎关注,及时了解更多此系列文章。

cifar数据集是以cifar-10-python.tar.gz的压缩包格式存储在远程服务器,利用keras的get_file()方法下载压缩包并执行解压,解压后得到:

cifar-10-batches-py

├── batches.meta

├── data_batch_1

├── data_batch_2

├── data_batch_3

├── data_batch_4

├── data_batch_5

├── readme.html

└── test_batch

其中data_batch_[1..5]为训练集数据,test_batch为测试集数据。

def load_data():

"""Loads CIFAR10 dataset.

# Returns

Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.

"""

dirname = 'cifar-10-batches-py'

origin = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'

path = get_file(dirname, origin=origin, untar=True)

num_train_samples = 50000

x_train = np.empty((num_train_samples, 3, 32, 32), dtype='uint8')

y_train = np.empty((num_train_samples,), dtype='uint8')

for i in range(1, 6):

fpath = os.path.join(path, 'data_batch_' + str(i))

(x_train[(i - 1) * 10000: i * 10000, :, :, :],

y_train[(i - 1) * 10000: i * 10000]) = load_batch(fpath)

fpath = os.path.join(path, 'test_batch')

x_test, y_test = load_batch(fpath)

y_train = np.reshape(y_train, (len(y_train), 1))

y_test = np.reshape(y_test, (len(y_test), 1))

if K.image_data_format() == 'channels_last':

x_train = x_train.transpose(0, 2, 3, 1)

x_test = x_test.transpose(0, 2, 3, 1)

return (x_train, y_train), (x_test, y_test)

data_batch_i 存放了cifar的训练集数据,每个文件1万条数据,采用pickle的方式进行序列化数据,利用pickle.load()的方式加载文件并反序列化为之前的dict(),该字典中有’data’和’label’两个key,分别存放了数据和标签。

def load_batch(fpath, label_key='labels'):

"""Internal utility for parsing CIFAR data.

# Arguments

fpath: path the file to parse.

label_key: key for label data in the retrieve

dictionary.

# Returns

A tuple `(data, labels)`.

"""

with open(fpath, 'rb') as f:

if sys.version_info < (3,):

d = cPickle.load(f)

else:

d = cPickle.load(f, encoding='bytes')

# decode utf8

d_decoded = {}

for k, v in d.items():

d_decoded[k.decode('utf8')] = v

d = d_decoded

data = d['data']

labels = d[label_key]

data = data.reshape(data.shape[0], 3, 32, 32)

return data, labels

where2 go 团队

   

微信号:算法与编程之美          

深度学习实战 cifar数据集预处理技术分析

长按识别二维码关注我们!

温馨提示: 点击页面右下角 “写留言”发表评论,期待您的参与!期待您的转发!


以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

深度探索C++对象模型

深度探索C++对象模型

斯坦利•B.李普曼 (Stanley B. Lippman) / 侯捷 / 电子工业出版社 / 2012-1 / 69.00元

作者Lippman参与设计了全世界第一套C++编译程序cfront,这本书就是一位伟大的C++编译程序设计者向你阐述他如何处理各种explicit(明确出现于C++程序代码中)和implicit(隐藏于程序代码背后)的C++语意。 本书专注于C++面向对象程序设计的底层机制,包括结构式语意、临时性对象的生成、封装、继承,以及虚拟——虚拟函数和虚拟继承。这本书让你知道:一旦你能够了解底层实现模......一起来看看 《深度探索C++对象模型》 这本书的介绍吧!

JS 压缩/解压工具
JS 压缩/解压工具

在线压缩/解压 JS 代码

RGB HSV 转换
RGB HSV 转换

RGB HSV 互转工具

HEX CMYK 转换工具
HEX CMYK 转换工具

HEX CMYK 互转工具