A Concise Explanation of the Pain Points Solved by TensorFlow Convolutional Neural Networks (CNN)
Project Overview
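Garbage classification is an important part of achieving sustainable development. This tutorial uses a classic TensorFlow convolutional neural network (CNN) example to take you through the whole workflow, from environment setup to single-image inference. It skips lengthy background and covers only the key steps, so you can quickly build an efficient, interpretable automated classification system. If you want an environment identical to mine in one step, install Conda first. If you have questions, see my earlier article "Getting Started with Conda from Zero: A Complete Guide to Installation, Creating Environments, and Managing Dependencies".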
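- Environment management: reproduce the setup in one step with environment.yml
- Dataset preparation: download link and directory structure
- Data cleaning: automatically remove corrupted images
- Data augmentation: improve model robustness
- Model building and training: the CNN architecture explained
- Training visualization: Loss/Accuracy curves
- Single-image inference: real-time classification and interpretability analysis
- CNN vs. Transformer: a guide to choosing an architecture

1. Environment Management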
Create a file named environment.yml in the project root. Example contents:
name: tf_gpu
channels:
  - defaults
  - conda-forge
dependencies:
  - _openmp_mutex=4.5
  - blas=1.0
  - brotli-python=1.0.9
  - bzip2=1.0.8
  - ca-certificates=2025.4.26
  - contourpy=1.3.1
  - cudatoolkit=11.2.2
  - cudnn=8.1.0.77
  - cycler=0.11.0
  - expat=2.7.1
  - fonttools=4.55.3
  - freetype=2.13.3
  - glib=2.84.0
  - glib-tools=2.84.0
  - gst-plugins-base=1.24.7
  - gstreamer=1.24.7
  - icc_rt=2022.1.0
  - icu=75.1
  - intel-openmp=2023.2.0
  - joblib=1.4.2
  - kiwisolver=1.4.8
  - krb5=1.21.3
  - lcms2=2.17
  - lerc=4.0.0
  - libblas=3.9.0
  - libcblas=3.9.0
  - libclang13=20.1.7
  - libdeflate=1.24
  - libffi=3.4.4
  - libfreetype=2.13.3
  - libfreetype6=2.13.3
  - libgcc=15.1.0
  - libglib=2.84.0
  - libgomp=15.1.0
  - libhwloc=2.11.2
  - libiconv=1.18
  - libintl=0.22.5
  - libintl-devel=0.22.5
  - libjpeg-turbo=3.1.0
  - liblapack=3.9.0
  - liblzma=5.8.1
  - liblzma-devel=5.8.1
  - libogg=1.3.5
  - libpng=1.6.47
  - libsqlite=3.50.1
  - libtiff=4.7.0
  - libvorbis=1.3.7
  - libwebp-base=1.5.0
  - libwinpthread=12.0.0.r4.gg4f2fc60ca
  - libxcb=1.17.0
  - libxml2=2.13.8
  - libzlib=1.3.1
  - matplotlib=3.10.0
  - matplotlib-base=3.10.0
  - mkl=2023.2.0
  - mkl-service=2.4.1
  - openjpeg=2.5.3
  - openssl=3.5.0
  - pcre2=10.44
  - pillow=11.2.1
  - pip=25.1
  - ply=3.11
  - pthread-stubs=0.4
  - pyparsing=3.2.0
  - pyqt=5.15.10
  - pyqt5-sip=12.13.0
  - python=3.10.16
  - python-dateutil=2.9.0post0
  - python_abi=3.10
  - qt-main=5.15.15
  - scikit-learn=1.6.1
  - setuptools=78.1.1
  - sip=6.7.12
  - six=1.17.0
  - sqlite=3.45.3
  - tbb=2021.13.0
  - threadpoolctl=3.5.0
  - tk=8.6.13
  - tomli=2.0.1
  - tornado=6.5.1
  - tzdata=2025b
  - ucrt=10.0.22621.0
  - unicodedata2=15.1.0
  - vc=14.42
  - vc14_runtime=14.42.34438
  - vs2015_runtime=14.42.34438
  - wheel=0.45.1
  - xorg-libxau=1.0.12
  - xorg-libxdmcp=1.1.5
  - xz=5.8.1
  - xz-tools=5.8.1
  - zlib=1.3.1
  - zstd=1.5.7
  - pip:
      - absl-py==2.3.0
      - astunparse==1.6.3
      - cachetools==5.5.2
      - certifi==2025.4.26
      - charset-normalizer==3.4.2
      - flatbuffers==25.2.10
      - gast==0.4.0
      - google-auth==2.40.3
      - google-auth-oauthlib==0.4.6
      - google-pasta==0.2.0
      - grpcio==1.73.0
      - h5py==3.14.0
      - idna==3.10
      - keras==2.10.0
      - keras-preprocessing==1.1.2
      - libclang==18.1.1
      - markdown==3.8
      - markupsafe==3.0.2
      - numpy==1.23.5
      - oauthlib==3.2.2
      - opt-einsum==3.4.0
      - packaging==25.0
      - protobuf==3.19.6
      - pyasn1==0.6.1
      - pyasn1-modules==0.4.2
      - requests==2.32.4
      - requests-oauthlib==2.0.0
      - rsa==4.9.1
      - scipy==1.15.3
      - tensorboard==2.10.1
      - tensorboard-data-server==0.6.1
      - tensorboard-plugin-wit==1.8.1
      - tensorflow==2.10.0
      - tensorflow-estimator==2.10.0
      - tensorflow-io-gcs-filesystem==0.31.0
      - termcolor==3.1.0
      - typing-extensions==4.14.0
      - urllib3==2.4.0
      - werkzeug==3.1.3
      - wrapt==1.17.2
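prefix: C:\Users\Wenhao\.conda\envs\tf_gpu

Create and activate the environment in one step (the environment name comes from the name field above):
conda env create -f environment.yml
conda activate tf_gpu

2. Dataset Preparation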
2.1 Download and Extract
Source: Alibaba Cloud Tianchi [Garbage Classification Dataset] https://tianchi.aliyun.com/dataset/138860
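After you extract the archive into dataset/ with one subfolder per class (the layout in 2.2), a quick per-class count can reveal class imbalance before training. The helper below is a hypothetical sketch and not part of the original project. The set of image extensions it counts is an assumption, so adjust it to your files.

```python
import os

# Assumed list of image extensions to count (adjust to your data)
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".ppm", ".tif", ".tiff"}


def count_images_per_class(data_dir):
    """Hypothetical helper: number of image files under each class subfolder of data_dir."""
    counts = {}
    for class_name in sorted(os.listdir(data_dir)):
        class_dir = os.path.join(data_dir, class_name)
        if not os.path.isdir(class_dir):
            continue  # skip stray files at the top level
        total = 0
        for _, _, files in os.walk(class_dir):  # also counts nested subfolders
            total += sum(1 for f in files if os.path.splitext(f)[1].lower() in IMAGE_EXTS)
        counts[class_name] = total
    return counts


if __name__ == "__main__" and os.path.isdir("dataset"):
    for name, n in count_images_per_class("dataset").items():
        print(f"{name:12}: {n}")
```

If one class has far fewer images than the others, validation accuracy alone can look better than the per-class results really are.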
2.2 Directory Structure
project-root/
├── dataset/
│ ├── Harmful/ # hazardous waste
│ ├── Kitchen/ # kitchen (food) waste
│ ├── Other/ # other waste
│ └── Recyclable/ # recyclable waste
├── clean_data.py
├── train.py
├── visualize.py
├── predict.py
└── environment.yml

3. Data Cleaning
3.1 Purpose
Automatically remove images that cannot be opened or are truncated, so they do not interrupt training.
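The script in 3.2 relies on Pillow's verify(). As far as I know, verify() does not decode the full pixel data, and for JPEG it checks very little, so a JPEG that is cut off mid-file may still pass. A cheap complementary check inspects the file signature and the end marker. A complete JPEG ends with the EOI marker FF D9, and a complete PNG ends with an IEND chunk. The helper below is a hypothetical heuristic, not part of the original project and not a substitute for actually decoding the image.

```python
# Hypothetical pre-filter: flag obviously truncated JPEG/PNG files without decoding them.
JPEG_SOI = b"\xff\xd8\xff"      # first bytes of a JPEG file
JPEG_EOI = b"\xff\xd9"          # JPEG end-of-image marker
PNG_SIG = b"\x89PNG\r\n\x1a\n"  # 8-byte PNG signature
PNG_IEND = b"IEND"              # type of the final PNG chunk


def looks_complete(path):
    """True/False for JPEG and PNG files; None for formats this heuristic does not check."""
    with open(path, "rb") as f:
        data = f.read()
    if data.startswith(JPEG_SOI):
        # Some encoders pad with zero bytes after EOI, so strip them before checking.
        return data.rstrip(b"\x00").endswith(JPEG_EOI)
    if data.startswith(PNG_SIG):
        # The last chunk is: 4-byte length + b"IEND" + 4-byte CRC.
        return data[-8:-4] == PNG_IEND
    return None
```

In clean_data.py, you could append paths for which looks_complete(path) is False to bad_images. A return value of None means the format was not checked.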
3.2 Implementation
# clean_data.py
import os
from PIL import Image, ImageFile

# Allow Pillow to load truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True
DATA_DIR = "dataset/"

bad_images = []
for root, _, files in os.walk(DATA_DIR):
    for fname in files:
        path = os.path.join(root, fname)
        try:
            with Image.open(path) as img:
                img.verify()  # integrity check; raises if the file is unreadable
        except Exception:
            bad_images.append(path)

if bad_images:
    print(f"Removing {len(bad_images)} corrupted images")
    for p in bad_images:
        os.remove(p)
        print("  ✔", p)
else:
    print("✅ No corrupted images found")

Run it from the project root:
python clean_data.py

4. Data Augmentation
4.1 Augmentation Techniques
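- Geometric transforms: rotation, translation, shear, zoom
- Color transforms: brightness, channel shift
- Flipping and filling: horizontal flip + reflection at the borders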
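The sketches below make the list above concrete. They are minimal NumPy versions of three of the transforms: a horizontal flip, a brightness scale plus a single color shift, and a horizontal shift whose gap is filled by mirroring the border. NumPy's "symmetric" padding follows the abcddcba|abcd pattern that the Keras documentation gives for fill_mode="reflect". These are simplified illustrations, not ImageDataGenerator's implementation, which samples random parameters per image and interpolates.

```python
import numpy as np


def horizontal_flip(img):
    """Mirror an (H, W, C) image left to right."""
    return img[:, ::-1, :]


def brightness_and_channel_shift(img, brightness, shift):
    """Scale pixel values by `brightness`, add one offset to every channel, clip to [0, 255]."""
    out = img.astype(np.float32) * brightness + shift
    return np.clip(out, 0, 255)


def shift_with_reflect(img, dx):
    """Shift right by dx >= 0 pixels; the uncovered columns mirror the left border."""
    padded = np.pad(img, ((0, 0), (dx, 0), (0, 0)), mode="symmetric")
    return padded[:, : img.shape[1], :]


rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(4, 5, 3)).astype(np.float32)
augmented = shift_with_reflect(
    brightness_and_channel_shift(horizontal_flip(image), rng.uniform(0.8, 1.2), rng.uniform(-15, 15)),
    dx=1,
)
print(augmented.shape)  # (4, 5, 3)
```

Each transform preserves the image shape, which is why they can be chained freely before the rescale step.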
4.2 Code Example
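# Data-generator part of train.py
from tensorflow.keras.preprocessing.image import ImageDataGenerator

IMAGE_SIZE = (128, 128)
BATCH_SIZE = 32

# Training data: rescaling + random augmentation
datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.2,
    rotation_range=20,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=10,
    zoom_range=0.2,
    brightness_range=[0.8, 1.2],
    channel_shift_range=15,
    horizontal_flip=True,
    fill_mode="reflect",
)
# Validation data: rescaling only, so metrics are measured on unaltered images.
# The same validation_split yields the same file split as `datagen`.
val_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

train_gen = datagen.flow_from_directory(
    "dataset/",
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode="categorical",
    subset="training",
)
val_gen = val_datagen.flow_from_directory(
    "dataset/",
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode="categorical",
    subset="validation",
)

5. Model Building and Training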
5.1 Model Architecture
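- Convolution + pooling layers: extract features at multiple levels
- Batch normalization: stabilizes and speeds up training
- Global average pooling: few parameters, helps prevent overfitting
- Fully connected + Dropout: classification output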
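You can check by hand the claim that global average pooling keeps the parameter count small. The sketch below uses hypothetical helper functions, not part of the original code, and counts parameters the way Keras reports them in model.summary(). It compares the 5.2 model (BatchNormalization + GlobalAveragePooling2D) with the Flatten variant used in the full training script. NUM_CLASSES = 4 assumes the four class folders from 2.2.

```python
def conv2d_params(k, c_in, c_out):
    return k * k * c_in * c_out + c_out  # kernel weights + biases


def batchnorm_params(c):
    return 4 * c  # gamma, beta (trainable) + moving mean, moving variance (non-trainable)


def dense_params(n_in, n_out):
    return n_in * n_out + n_out


def conv_stack(h, w, with_bn):
    """Three Conv2D(3x3, 'valid') + MaxPooling2D(2) blocks with 3 -> 32 -> 64 -> 128 channels."""
    total, c_in = 0, 3
    for c_out in (32, 64, 128):
        total += conv2d_params(3, c_in, c_out)
        if with_bn:
            total += batchnorm_params(c_out)
        h, w = (h - 2) // 2, (w - 2) // 2  # 3x3 valid conv shrinks by 2, then 2x2 pooling halves
        c_in = c_out
    return total, h, w, c_in


NUM_CLASSES = 4  # assumption: Harmful, Kitchen, Other, Recyclable

# 5.2 model: BatchNorm + GlobalAveragePooling2D -> Dense(128) -> Dense(4)
conv_total, h, w, c = conv_stack(128, 128, with_bn=True)
gap_model = conv_total + dense_params(c, 128) + dense_params(128, NUM_CLASSES)

# Full-script model: no BatchNorm, Flatten -> Dense(128) -> Dense(4)
conv_total, h, w, c = conv_stack(128, 128, with_bn=False)
flatten_model = conv_total + dense_params(h * w * c, 128) + dense_params(128, NUM_CLASSES)

print(gap_model, flatten_model)  # 111172 3305156
```

Nearly all of the Flatten variant's roughly 3.3 M parameters sit in the Dense(128) layer fed by 14 × 14 × 128 = 25,088 flattened features. Global average pooling shrinks that input to 128 values.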
5.2 Training Script
# train.py
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (
    Conv2D, BatchNormalization, MaxPooling2D,
    GlobalAveragePooling2D, Dense, Dropout,
)
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

print("✅ GPU:", tf.config.list_physical_devices("GPU"))

# train_gen / val_gen come from the data-generator code in 4.2
num_classes = train_gen.num_classes

model = Sequential([
    Conv2D(32, 3, activation="relu", input_shape=(128, 128, 3)),
    BatchNormalization(), MaxPooling2D(),
    Conv2D(64, 3, activation="relu"),
    BatchNormalization(), MaxPooling2D(),
    Conv2D(128, 3, activation="relu"),
    BatchNormalization(), MaxPooling2D(),
    GlobalAveragePooling2D(),
    Dense(128, activation="relu"),
    Dropout(0.5),
    Dense(num_classes, activation="softmax"),
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)
model.summary()

callbacks = [
    # Stop after 5 epochs without val_loss improvement and keep the best weights
    EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True),
    # Halve the learning rate after 3 epochs without val_loss improvement
    ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=3),
]

history = model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=50,
    callbacks=callbacks,
)

model.save("custom_garbage_classifier.h5")
print("✅ Model saved to custom_garbage_classifier.h5")

6. Training Process Visualization
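# visualize.py
# Needs the `history` object returned by model.fit in train.py,
# so run it at the end of train.py (or in the same session).
import matplotlib.pyplot as plt

# Loss curves
plt.plot(history.history["loss"], label="train_loss")
plt.plot(history.history["val_loss"], label="val_loss")
plt.title("Loss curve")
plt.xlabel("Epoch")
plt.legend()
plt.show()

# Accuracy curves
plt.plot(history.history["accuracy"], label="train_acc")
plt.plot(history.history["val_accuracy"], label="val_acc")
plt.title("Accuracy curve")
plt.xlabel("Epoch")
plt.legend()
plt.show()

Full training code: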
import os
from PIL import Image, ImageFile

# Allow Pillow to load truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Dataset path
DATA_DIR = "dataset/"

# Step 1: automatically remove all corrupted or truncated images
bad_images = []
for root, _, files in os.walk(DATA_DIR):
    for fname in files:
        path = os.path.join(root, fname)
        try:
            with Image.open(path) as img:
                img.verify()
        except Exception:
            bad_images.append(path)

if bad_images:
    print(f"Found {len(bad_images)} bad images. Removing…")
    for p in bad_images:
        os.remove(p)
        print("  Removed", p)
else:
    print("No corrupted images found.")

# --- Training script below --- #
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.callbacks import EarlyStopping

# Check whether a GPU is available
print("✅ GPU devices:", tf.config.list_physical_devices("GPU"))

# Hyperparameters
IMAGE_SIZE = (128, 128)
BATCH_SIZE = 32
EPOCHS = 15

# Data augmentation + preprocessing (training data)
datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.2,
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True,
)
# Validation data: rescaling only, same file split
val_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

train_gen = datagen.flow_from_directory(
    DATA_DIR, target_size=IMAGE_SIZE, batch_size=BATCH_SIZE,
    class_mode="categorical", subset="training",
)
val_gen = val_datagen.flow_from_directory(
    DATA_DIR, target_size=IMAGE_SIZE, batch_size=BATCH_SIZE,
    class_mode="categorical", subset="validation",
)

# Custom CNN architecture
model = Sequential([
    Conv2D(32, (3, 3), activation="relu", input_shape=(IMAGE_SIZE[0], IMAGE_SIZE[1], 3)),
    MaxPooling2D(2, 2),
    Conv2D(64, (3, 3), activation="relu"),
    MaxPooling2D(2, 2),
    Conv2D(128, (3, 3), activation="relu"),
    MaxPooling2D(2, 2),
    Flatten(),
    Dense(128, activation="relu"),
    Dropout(0.5),
    Dense(train_gen.num_classes, activation="softmax"),
])

# Compile the model
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

# Print the model structure
model.summary()

# EarlyStopping to guard against overfitting
early_stop = EarlyStopping(monitor="val_loss", patience=3, restore_best_weights=True)

# Train the model
history = model.fit(train_gen, validation_data=val_gen, epochs=EPOCHS, callbacks=[early_stop])

# Save the model
model.save("custom_garbage_classifier.h5")
print("✅ Training finished; model saved as custom_garbage_classifier.h5")

7. Single-Image Inference and Explainable AI
# predict.py
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array

model = load_model("custom_garbage_classifier.h5")
img_path = "evalImageSet/5.jpg"
IMG_SIZE = (128, 128)

img = load_img(img_path, target_size=IMG_SIZE)
x = img_to_array(img) / 255.0   # same rescaling as in training
x = np.expand_dims(x, 0)        # add the batch dimension

probs = model.predict(x)[0]
class_idx = int(np.argmax(probs))

# Must match train_gen.class_indices (folder names in alphabetical order)
class_indices = {"Harmful": 0, "Kitchen": 1, "Other": 2, "Recyclable": 3}
labels = {v: k for k, v in class_indices.items()}

print(f"▶ {img_path} → {labels[class_idx]} ({probs[class_idx]:.1%})")
print("Per-class probabilities:")
for i, p in enumerate(probs):
    print(f"  {labels[i]:12}: {p:.2%}")

plt.imshow(img)
plt.title(f"{labels[class_idx]} ({probs[class_idx]:.1%})")
plt.axis("off")
plt.show()

Optional: use Grad-CAM to visualize the regions the model focuses on.
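Grad-CAM weights the last convolutional layer's feature maps A_k by the spatially averaged gradients α_k of the class score y_c with respect to those maps. It then applies a ReLU: heatmap = ReLU(Σ_k α_k A_k). Getting A and the gradients requires TensorFlow, for example a sub-model that outputs the last Conv2D layer and the predictions, evaluated under tf.GradientTape. That part is not shown here. The NumPy sketch below covers only the combination step and assumes both arrays are already available.

```python
import numpy as np


def grad_cam_heatmap(activations, gradients):
    """Combine conv activations A (H, W, K) and gradients dy_c/dA (H, W, K)
    into a Grad-CAM heatmap of shape (H, W), scaled to [0, 1]."""
    weights = gradients.mean(axis=(0, 1))                        # alpha_k: spatially averaged gradients
    cam = np.maximum((activations * weights).sum(axis=-1), 0.0)  # ReLU(sum_k alpha_k * A_k)
    peak = cam.max()
    return cam / peak if peak > 0 else cam                       # avoid dividing by zero


# Shapes assumed to match the last Conv2D output of the 5.2 model (28 x 28 x 128)
rng = np.random.default_rng(0)
heatmap = grad_cam_heatmap(rng.random((28, 28, 128)), rng.normal(size=(28, 28, 128)))
print(heatmap.shape)  # (28, 28)
```

Before overlaying the 28 × 28 heatmap on the 128 × 128 input, you would resize it (for example with tf.image.resize) and draw it semi-transparently with plt.imshow.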
Full inference code:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array

# 1. Load the model
model = load_model("custom_garbage_classifier.h5")

# 2. Path of the image to analyze
img_path = "evalImageSet/5.jpg"  # replace with your own image

# 3. Load and preprocess
IMG_SIZE = (128, 128)
img = load_img(img_path, target_size=IMG_SIZE)
x = img_to_array(img) / 255.0     # normalize to [0, 1]
x = np.expand_dims(x, axis=0)     # shape becomes (1, 128, 128, 3)

# 4. Predict
probs = model.predict(x)[0]       # vector with one probability per class
class_idx = int(np.argmax(probs)) # index of the predicted class

# 5. Map the index back to a class name
# This assumes a class_indices dict taken from the training generator
# (train_gen.class_indices), e.g. {"glass": 0, "paper": 1, "plastic": 2, ...}
# Replace it with your actual mapping
class_indices = {"Harmful": 0, "Kitchen": 1, "Other": 2, "Recyclable": 3}
labels = {v: k for k, v in class_indices.items()}

pred_label = labels[class_idx]
pred_prob = probs[class_idx]

# 6. Print the results
print(f"▶ Image analyzed: {img_path}")
print(f"Predicted class: {pred_label}, confidence: {pred_prob:.4%}")
print("\nPer-class probabilities:")
for idx, p in enumerate(probs):
    print(f"  {labels[idx]:12}: {p:.2%}")

# 7. (Optional) show the image
plt.imshow(img)
plt.title(f"Pred: {pred_label} ({pred_prob:.1%})")
plt.axis("off")
plt.show()

8. CNN vs. Transformer
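| Aspect | CNN | Transformer |
|---|---|---|
| Core modules | Conv2D + Pooling | Self-Attention + Feed-Forward |
| Receptive field | Grows as layers are stacked | Global within a single layer |
| Parameter sharing | Convolution kernels are reused across space/time | Q/K/V projection and feed-forward weights are shared across all tokens (the attention weights themselves are computed per token pair) |
| Position sensitivity | Translation-invariant; no explicit positional encoding needed | Self-attention is order-agnostic, so explicit positional encoding is required to make it order-sensitive |
| Parallelism | High (local parallelism) | Very high (global parallelism) |
| Computational complexity | O(N·K²·C_in·C_out) | O(N²·D) |

N: number of pixels (CNN) or tokens (Transformer); K: kernel size; C_in/C_out: input/output channels; D: hidden dimension.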
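To put numbers on the complexity row, the sketch below counts multiply-accumulates (MACs) for two layers. The first is the third Conv2D of the 5.2 model (28 × 28 output, 3 × 3 kernel, 64 → 128 channels). The second is a hypothetical single-head self-attention layer over the same 784 positions with D = 128. The attention count covers only the two N × N products (QKᵀ and the weighted sum over V). It deliberately leaves out the Q/K/V/output projections, which add about 4·N·D² more.

```python
def conv2d_macs(h_out, w_out, k, c_in, c_out):
    """MACs of one K x K convolution: N * K^2 * C_in * C_out, with N = H_out * W_out."""
    return h_out * w_out * k * k * c_in * c_out


def self_attention_macs(n_tokens, d):
    """MACs of the two N x N products in one attention head (Q @ K^T and weights @ V);
    the Q/K/V/output projections are excluded."""
    return 2 * n_tokens * n_tokens * d


conv = conv2d_macs(28, 28, 3, 64, 128)     # third Conv2D of the 5.2 model
attn = self_attention_macs(28 * 28, 128)   # hypothetical attention over the same 784 positions
print(conv, attn)  # 57802752 157351936
```

At this resolution the attention products already cost more, and the gap widens as resolution grows. Doubling height and width multiplies N by 4, which scales the convolution by 4 but the attention products by 16.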
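Common ground: both use the same underlying tensor operations, backpropagation, optimizers, and regularization methods. Choosing: decide based on whether the task is dominated by local or global dependencies; the two can also be combined (ViT, Conformer, CLIP). With the material above you should now have a basic understanding of this approach, and I have demonstrated its basic usage as well. If you can put it all together, I'm sure you'll be very good at this!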
Best Wenhao (楠博万)