加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
data_provider.py 1.27 KB
一键复制 编辑 原始数据 按行查看 历史
程序源码设计 提交于 2023-04-26 16:16 . version 1
import os
import tensorflow as tf
_FILE_PATTERN = 'FACE_%s.tfrecord'
dataset_dir = 'data'
reader = tf.TFRecordReader()
keys_to_features = {
'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
'image/format': tf.FixedLenFeature((), tf.string, default_value='raw'),
'image/class/label': tf.FixedLenFeature([1], tf.int64),
}
num_classes = 2
def get_data(split_name):
file_pattern = os.path.join(dataset_dir, _FILE_PATTERN % split_name)
filename_queue = tf.train.string_input_producer([file_pattern])
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example,features = keys_to_features )
#image = tf.decode_raw(features['image/encoded'], tf.uint8)
image = tf.image.decode_png(features['image/encoded'])
#label = tf.cast(features['image/class/label'],tf.float32)
label = tf.one_hot(features['image/class/label'], num_classes)
label = tf.reshape(label, shape=(num_classes,))
print ("label:", label)
image = tf.image.convert_image_dtype(image, tf.float32)
image -= 0.5
image *= 2
image = tf.reshape(image, shape=(64*64,))
print (image, label)
return (image, label)
#test_image, test_label = get_data("test")
#print (test_image)
#print (test_label)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化