代码拉取完成,页面将自动刷新
import os
import random
import shutil
from config import TRAIN_SET_RATIO, TEST_SET_RATIO,dataset_dir
import tensorflow as tf
class SplitDataset:
    """Split a label-per-folder dataset into train/valid/test copies.

    ``dataset_dir`` is expected to contain one sub-directory per label; every
    regular file directly inside a label directory is treated as one sample.
    Each label's samples are shuffled independently and copied into
    ``<saved_dataset_dir>/train/<label>/``, ``.../valid/<label>/`` and
    ``.../test/<label>/``.

    Args:
        dataset_dir: Source directory with one sub-directory per label.
        saved_dataset_dir: Destination root; it and its train/valid/test
            sub-directories are created if missing.
        train_ratio: Fraction of each label's files copied to ``train``.
        test_ratio: Fraction of each label's files copied to ``test``.
            Whatever remains after train and test goes to ``valid``.
        show_progress: If true, print one line per copied file.

    Raises:
        ValueError: If a ratio is outside [0, 1] or train + test exceeds 1.
    """

    # Tolerance for binary floating-point error when checking
    # train + test <= 1 (e.g. 1 - 0.8 - 0.2 is slightly below zero).
    _RATIO_EPS = 1e-9

    def __init__(self, dataset_dir, saved_dataset_dir, train_ratio=TRAIN_SET_RATIO,
                 test_ratio=TEST_SET_RATIO, show_progress=False):
        for ratio_name, ratio in (("train_ratio", train_ratio), ("test_ratio", test_ratio)):
            # Written as `not (0 <= r <= 1)` so NaN is rejected too.
            if not 0 <= ratio <= 1:
                raise ValueError("%s must be in [0, 1], got %r" % (ratio_name, ratio))
        if train_ratio + test_ratio > 1 + self._RATIO_EPS:
            raise ValueError(
                "train_ratio + test_ratio must not exceed 1, got %r + %r"
                % (train_ratio, test_ratio)
            )
        self.dataset_dir = dataset_dir
        self.saved_dataset_dir = saved_dataset_dir
        self.saved_train_dir = saved_dataset_dir + "/train/"
        self.saved_valid_dir = saved_dataset_dir + "/valid/"
        self.saved_test_dir = saved_dataset_dir + "/test/"
        self.train_ratio = train_ratio
        self.test_ratio = test_ratio
        self.valid_ratio = 1 - train_ratio - test_ratio
        # Each entry is [label, [file_path, ...]], filled by __split_dataset.
        self.train_file_path = []
        self.valid_file_path = []
        self.test_file_path = []
        # Maps label index (position in the sorted label list) -> label name.
        self.index_label_dict = {}
        self.show_progress = show_progress
        # makedirs also creates saved_dataset_dir itself if it is missing;
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        for split_dir in (self.saved_train_dir, self.saved_test_dir, self.saved_valid_dir):
            os.makedirs(split_dir, exist_ok=True)

    @property
    def test_radio(self):
        """Backward-compatible alias for the originally misspelled attribute."""
        return self.test_ratio

    @test_radio.setter
    def test_radio(self, value):
        self.test_ratio = value

    def __get_label_names(self):
        """Return the sorted names of the sub-directories of ``dataset_dir``."""
        # Sorted so the index -> label mapping does not depend on the
        # filesystem's arbitrary os.listdir order.
        return sorted(
            item for item in os.listdir(self.dataset_dir)
            if os.path.isdir(os.path.join(self.dataset_dir, item))
        )

    def __get_all_file_path(self):
        """Return one sorted list of sample file paths per label.

        Also refills ``self.index_label_dict`` so that position ``i`` of the
        returned list holds the files of label ``self.index_label_dict[i]``.
        """
        self.index_label_dict.clear()
        all_file_path = []
        for index, label in enumerate(self.__get_label_names()):
            self.index_label_dict[index] = label
            label_dir = os.path.join(self.dataset_dir, label)
            # Only regular files: shutil.copy cannot copy a nested directory.
            # Sorting makes the pre-shuffle order (and thus a seeded shuffle)
            # reproducible across filesystems.
            all_file_path.append(sorted(
                os.path.join(label_dir, name)
                for name in os.listdir(label_dir)
                if os.path.isfile(os.path.join(label_dir, name))
            ))
        return all_file_path

    def __copy_files(self, type_path, type_saved_dir):
        """Copy each label's files into ``type_saved_dir`` + ``<label>/``.

        Args:
            type_path: Iterable of ``[label, [src_path, ...]]`` pairs as
                built by ``__split_dataset``.
            type_saved_dir: Destination root of one split (train/valid/test);
                expected to end with a path separator.
        """
        for label, src_path_list in type_path:
            dst_path = type_saved_dir + "%s/" % (label,)
            os.makedirs(dst_path, exist_ok=True)
            for src_path in src_path_list:
                shutil.copy(src_path, dst_path)
                if self.show_progress:
                    print("Copying file "+src_path+" to "+dst_path)

    def __split_dataset(self):
        """Shuffle each label's files and partition them into train/test/valid.

        For a label with ``n`` files, the first ``int(n * train_ratio)``
        shuffled files go to train, the next ``int(n * test_ratio)`` go to
        test, and the rest go to valid, so every file lands in exactly one
        split.
        """
        # Reset so that a second start_splitting() call does not append a
        # second, different split on top of the first one.
        self.train_file_path.clear()
        self.valid_file_path.clear()
        self.test_file_path.clear()
        all_file_paths = self.__get_all_file_path()
        for index, file_path_list in enumerate(all_file_paths):
            label = self.index_label_dict[index]
            random.shuffle(file_path_list)
            total = len(file_path_list)
            train_num = int(total * self.train_ratio)
            test_num = int(total * self.test_ratio)
            self.train_file_path.append([label, file_path_list[:train_num]])
            self.test_file_path.append([label, file_path_list[train_num:train_num + test_num]])
            self.valid_file_path.append([label, file_path_list[train_num + test_num:]])

    def start_splitting(self):
        """Compute the split and copy the files into the train/valid/test dirs."""
        self.__split_dataset()
        self.__copy_files(type_path=self.train_file_path, type_saved_dir=self.saved_train_dir)
        self.__copy_files(type_path=self.valid_file_path, type_saved_dir=self.saved_valid_dir)
        self.__copy_files(type_path=self.test_file_path, type_saved_dir=self.saved_test_dir)
if __name__ == '__main__':
    # Ask TensorFlow to allocate GPU memory on demand rather than reserving
    # all of it up front. Nothing below uses TensorFlow; this only configures
    # whatever GPU devices tf.config reports (none -> the loop is a no-op).
    for device in tf.config.experimental.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(device, True)

    # Split the local "data" folder into the destination configured as
    # config.dataset_dir, printing each copied file.
    splitter = SplitDataset(
        dataset_dir="data",
        saved_dataset_dir=dataset_dir,
        show_progress=True,
    )
    splitter.start_splitting()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。