加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
util.py 4.54 KB
一键复制 编辑 原始数据 按行查看 历史
孙的空间 提交于 2024-11-01 19:30 . 98
import cv2 # 用于图像读取和处理
import numpy as np # 主要用于数值计算
import os # 文件和目录操作
from os.path import exists # 用于检测文件或目录是否存在
from imutils import paths # 获取文件路径
import pickle # 用于序列化和反序列化Python对象
from tqdm import tqdm # 用于在循环中添加进度条
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input # VGG16模型及预处理函数
from tensorflow.keras.preprocessing import image # 用于加载和处理图像
import logging # 日志记录
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') # 配置日志格式
########## 步骤1:定义文件大小检查函数 ##########
def get_size(file):
"""
获取文件大小并转换为兆字节表示
"""
size_bytes = os.path.getsize(file) # 获取文件的字节大小
size_megabytes = size_bytes / (1024 * 1024) # 字节转兆字节
return size_megabytes # 返回大小
########## 步骤2:定义特征与标签构建主函数 ##########
def createXY(train_folder, dest_folder, method='vgg', batch_size=64):
"""
从图像数据创建特征 (X) 和标签 (y),并保存到文件。
支持 'vgg' 或 'flat' 两种模式。
"""
x_file_path = os.path.join(dest_folder, "X.pkl") # 特征文件路径
y_file_path = os.path.join(dest_folder, "y.pkl") # 标签文件路径
# 如果文件已存在,则直接加载数据
if exists(x_file_path) and exists(y_file_path):
logging.info("X和y已经存在,直接读取")
logging.info(f"X文件大小:{get_size(x_file_path):.2f}MB")
logging.info(f"y文件大小:{get_size(y_file_path):.2f}MB")
with open(x_file_path, 'rb') as f:
X = pickle.load(f) # 读取X
with open(y_file_path, 'rb') as f:
y = pickle.load(f) # 读取y
return X, y # 返回数据
########## 步骤3:读取图像路径并初始化X和y ##########
logging.info("读取所有图像,生成X和y")
image_paths = list(paths.list_images(train_folder)) # 获取图像路径
X = [] # 存储特征
y = [] # 存储标签
# 选择模型或扁平化处理
if method == 'vgg':
model = VGG16(weights='imagenet', include_top=False, pooling="max")
logging.info("完成构建 VGG16 模型")
elif method == 'flat':
model = None # 若为'flat'方法,无需模型
########## 步骤4:分批加载和处理图像 ##########
num_batches = len(image_paths) // batch_size + (1 if len(image_paths) % batch_size else 0)
for idx in tqdm(range(num_batches), desc="读取图像"):
batch_images = [] # 当前批次图像
batch_labels = [] # 当前批次标签
# 获取当前批次图像范围
start = idx * batch_size
end = min((idx + 1) * batch_size, len(image_paths))
# 加载和预处理当前批次的图像
for i in range(start, end):
image_path = image_paths[i] # 获取图像路径
if method == 'vgg':
img = image.load_img(image_path, target_size=(224, 224)) # 加载并调整大小
img = image.img_to_array(img) # 转换为数组
elif method == 'flat':
img = cv2.imread(image_path, 0) # 读取为灰度图
img = cv2.resize(img, (32, 32)) # 调整尺寸
batch_images.append(img) # 添加到批次列表
# 解析图像文件名中的标签
label = 1 if image_path.split(os.path.sep)[-1].split(".")[0] == 'dog' else 0
batch_labels.append(label) # 添加标签
# 转换批次图像为数组并处理
batch_images = np.array(batch_images)
if method == 'vgg':
batch_images = preprocess_input(batch_images)
batch_pixels = model.predict(batch_images, verbose=0) # VGG16特征
else:
batch_pixels = batch_images.reshape((batch_images.shape[0], -1)) # 展平
X.extend(batch_pixels) # 添加特征到X
y.extend(batch_labels) # 添加标签到y
########## 步骤5:保存数据集 ##########
logging.info(f"X.shape: {np.shape(X)}")
logging.info(f"y.shape: {np.shape(y)}")
with open(x_file_path, 'wb') as f:
pickle.dump(X, f) # 保存X
logging.info(f"X文件大小: {get_size(x_file_path)} MB")
with open(y_file_path, 'wb') as f:
pickle.dump(y, f) # 保存y
logging.info(f"y文件大小: {get_size(y_file_path)} MB")
return X, y # 返回构建的特征和标签
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化