# ResNet50 image classification script

import subprocess
import sys

# 自动安装依赖
def install_packages(packages=None):
    """Install any third-party package this script needs that is not importable.

    Args:
        packages: Optional mapping of ``pip`` distribution name to the module
            name used in ``import`` statements. The two differ for some
            projects (``pillow`` is imported as ``PIL``). When omitted, the
            script's own dependency list is used.

    Raises:
        subprocess.CalledProcessError: If ``pip install`` exits with a
            non-zero status.
    """
    # Local import keeps this function self-contained; it runs before the
    # rest of the file's imports.
    import importlib.util

    if packages is None:
        packages = {
            "torch": "torch",
            "torchvision": "torchvision",
            "pillow": "PIL",  # distribution name differs from import name
            "pandas": "pandas",
            "openpyxl": "openpyxl",
            "requests": "requests",
        }

    installed_any = False
    for dist_name, module_name in packages.items():
        # find_spec only locates the module; it avoids importing heavy
        # packages such as torch just to check that they exist.
        if importlib.util.find_spec(module_name) is not None:
            continue
        subprocess.check_call([sys.executable, "-m", "pip", "install", dist_name])
        installed_any = True

    if installed_any:
        # Let the running interpreter see packages installed a moment ago.
        importlib.invalidate_caches()

# Runs at module load, before the third-party imports below, so that missing
# packages are installed before they are imported.
install_packages()

# === 导入依赖 ===
import os
import shutil
import torch
from torchvision import models, transforms
from PIL import Image
import pandas as pd
from collections import defaultdict
import requests

# === 图像预处理(ImageNet)===
# Per-channel normalization using the ImageNet mean / standard deviation.
_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Preprocessing pipeline applied to every image before it is fed to the model:
# resize, take a 224-pixel center crop, convert to a tensor, then normalize.
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    _normalize,
]
transform = transforms.Compose(_preprocess_steps)

# === 加载 ResNet50 模型 ===
# Load ImageNet-pretrained ResNet50. Newer torchvision releases deprecate the
# ``pretrained`` flag in favor of an explicit ``weights`` enum, so prefer the
# enum when it exists and fall back to the legacy flag on older releases.
_resnet50_weights = getattr(models, "ResNet50_Weights", None)
if _resnet50_weights is not None:
    # IMAGENET1K_V1 is the checkpoint the legacy ``pretrained=True`` selected.
    model = models.resnet50(weights=_resnet50_weights.IMAGENET1K_V1)
else:
    model = models.resnet50(pretrained=True)
# Inference mode: disables dropout and uses running batch-norm statistics.
model.eval()

# === 加载类别标签 ===
# https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt

LABELS_PATH = r"g:\Dataset\Image\COCO_Annotations\imagenet_classes.txt"
# URL recorded in this file as the source of the label list; used only when
# LABELS_PATH does not exist on this machine.
LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"

if os.path.isfile(LABELS_PATH):
    with open(LABELS_PATH, "r", encoding="utf-8") as f:
        _labels_text = f.read()
else:
    # Fall back to downloading so the script is not tied to one disk layout.
    _labels_response = requests.get(LABELS_URL, timeout=30)
    _labels_response.raise_for_status()
    _labels_text = _labels_response.text

# One label per line; the list index is used as the class index. Blank lines
# and surrounding whitespace are dropped so labels are safe as folder names.
labels = [line.strip() for line in _labels_text.splitlines() if line.strip()]


# === 单张图像分类 ===
def classify_image(image_path):
    """Return the top-1 predicted label and its probability for one image.

    The image is converted to RGB, run through the module-level ``transform``
    pipeline and the module-level ``model``, and the highest softmax score is
    looked up in the module-level ``labels`` list.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        A ``(label, probability)`` tuple: the label string at the predicted
        class index and the softmax probability as a Python float.

    Raises:
        OSError: If the file cannot be opened or decoded as an image.
    """
    # The context manager closes the underlying file handle deterministically;
    # ``convert`` returns an independent in-memory copy that outlives it.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    # Add a leading batch dimension of size 1 for the model.
    input_tensor = transform(image).unsqueeze(0)
    with torch.no_grad():
        output = model(input_tensor)
        # output[0] is the score vector of the single image in the batch.
        probs = torch.nn.functional.softmax(output[0], dim=0)
        top1_prob, top1_class = torch.topk(probs, 1)
    return labels[top1_class.item()], top1_prob.item()

# === 分类主流程 ===
def organize_images_by_class(src_dir, dst_dir):
    """Classify the images in ``src_dir`` and copy each into a per-class folder.

    Every file directly inside ``src_dir`` with a supported image extension is
    classified with ``classify_image`` and copied to
    ``dst_dir/<predicted class>/``. A workbook ``result.xlsx`` is written to
    ``dst_dir`` with one sheet of per-file predictions and one sheet of
    per-class counts sorted by descending count.

    Args:
        src_dir: Directory containing the images to classify (not recursive).
        dst_dir: Directory that receives the class folders and the workbook;
            created if it does not exist.

    Raises:
        FileNotFoundError: If ``src_dir`` does not exist.
    """
    image_exts = ('.jpg', '.jpeg', '.png', '.bmp')
    records = []
    class_counts = defaultdict(int)

    # Create the output root up front so the report can still be written when
    # no image is found.
    os.makedirs(dst_dir, exist_ok=True)

    # Sorted so the processing order and report rows are reproducible.
    for fname in sorted(os.listdir(src_dir)):
        if not fname.lower().endswith(image_exts):
            continue

        img_path = os.path.join(src_dir, fname)
        if not os.path.isfile(img_path):
            # A directory whose name happens to end in an image extension.
            continue

        try:
            predicted_class, prob = classify_image(img_path)
        except OSError as exc:
            # One corrupt or unreadable file should not abort the whole batch.
            print(f"⚠️ 跳过无法读取的图像 {fname}: {exc}")
            continue

        class_folder = os.path.join(dst_dir, predicted_class)
        os.makedirs(class_folder, exist_ok=True)

        shutil.copy(img_path, os.path.join(class_folder, fname))
        records.append({
            "文件名": fname,
            "预测类别": predicted_class,
            "置信度": round(prob, 4)
        })
        class_counts[predicted_class] += 1

    # Save results to Excel. Explicit column lists keep the header row even
    # when there are no records.
    df = pd.DataFrame(records, columns=["文件名", "预测类别", "置信度"])
    count_df = pd.DataFrame(
        sorted(class_counts.items(), key=lambda item: -item[1]),
        columns=["预测类别", "图片数量"],
    )
    output_excel = os.path.join(dst_dir, "result.xlsx")
    with pd.ExcelWriter(output_excel) as writer:
        df.to_excel(writer, index=False, sheet_name="分类结果")
        count_df.to_excel(writer, index=False, sheet_name="类别汇总")

    print(f"\n✅ 分类完成,结果保存在 {output_excel}")
    print(f"📁 分类后的文件位于: {dst_dir}")

# === 命令行入口 ===
if __name__ == "__main__":
    import argparse

    # Command-line entry point: both directories are optional flags whose
    # defaults are the hard-coded dataset locations below.
    cli = argparse.ArgumentParser(description="对图像进行分类并复制到对应类别目录")
    option_table = (
        ("--src", r"G:\Dataset\Image\images", "原始图像路径"),
        ("--dst", r"G:\Dataset\Image\tag", "分类结果保存路径"),
    )
    for flag, default_dir, help_text in option_table:
        cli.add_argument(flag, default=default_dir, help=help_text)
    options = cli.parse_args()

    organize_images_by_class(options.src, options.dst)
# ❤️ Please credit the source when reposting this article. Thank you! ❤️