Detectron2-物件辨識模型訓練
- 取得連結
- X
- 以電子郵件傳送
- 其他應用程式
註冊數據集
import json
from detectron2.structures import BoxMode
def get_board_dicts(imgdir):
    """Load and normalise the detectron2 dataset dicts for one split.

    Reads ``<imgdir>/dataset.json`` — a list of per-image records, each with a
    ``file_name`` and a list of ``annotations`` — and updates the records in
    place so detectron2 can consume them:

    * ``file_name`` is prefixed with ``imgdir`` (the JSON is assumed to store
      paths relative to the split directory — confirm against the exporter);
    * every annotation gets ``bbox_mode = BoxMode.XYWH_ABS``;
    * every annotation's ``category_id`` is cast to ``int``.

    Args:
        imgdir: Split directory containing ``dataset.json`` and the images.

    Returns:
        The list of record dicts loaded from ``dataset.json``.
    """
    json_file = imgdir + "/dataset.json"
    # JSON text is UTF-8; open() would otherwise use the OS locale encoding
    # (e.g. cp950 on Traditional-Chinese Windows) and break non-ASCII names.
    with open(json_file, encoding="utf-8") as f:
        dataset_dicts = json.load(f)
    for record in dataset_dicts:
        record["file_name"] = imgdir + "/" + record["file_name"]
        for ann in record["annotations"]:
            # detectron2 requires an explicit box layout on every annotation;
            # this dataset is declared to use absolute [x, y, width, height].
            ann["bbox_mode"] = BoxMode.XYWH_ABS
            # Cast in case the exporter serialised ids as strings or floats.
            ann["category_id"] = int(ann["category_id"])
    return dataset_dicts
from detectron2.data import DatasetCatalog, MetadataCatalog

# Class names for the board dataset. Their index presumably matches the
# integer category_id in dataset.json (0=HINDI, 1=ENGLISH, 2=OTHER) — confirm.
BOARD_CLASSES = ("HINDI", "ENGLISH", "OTHER")

# Register the train/val splits; each split directory holds a dataset.json
# read lazily by get_board_dicts. DatasetCatalog.register refuses duplicate
# names, so skip already-registered splits to make re-running this cell safe.
for d in ["train", "val"]:
    name = "boardetect_" + d
    if name not in DatasetCatalog.list():
        # d=d binds the current split now; a plain closure would see only "val".
        DatasetCatalog.register(name, lambda d=d: get_board_dicts("Text_Detection_Dataset_COCO_Format/" + d))
    # A fresh list per split, so mutating one split's metadata cannot affect the other.
    MetadataCatalog.get(name).set(thing_classes=list(BOARD_CLASSES))
board_metadata = MetadataCatalog.get("boardetect_train")
設定訓練配置
import os

from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultTrainer
from detectron2.evaluation import COCOEvaluator


class CocoTrainer(DefaultTrainer):
    """DefaultTrainer that evaluates the validation split with COCO metrics.

    detectron2's DefaultTrainer leaves build_evaluator unimplemented, so
    periodic evaluation (cfg.TEST.EVAL_PERIOD > 0) needs a subclass that
    supplies an evaluator.
    """

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """Return a COCOEvaluator for dataset_name.

        Evaluation output is written to output_folder, which defaults to
        <cfg.OUTPUT_DIR>/coco_eval and is created if missing.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "coco_eval")
        os.makedirs(output_folder, exist_ok=True)
        return COCOEvaluator(dataset_name, output_dir=output_folder)


# Model-zoo entry used for both the architecture config and the pretrained weights.
MODEL_ZOO_CONFIG = "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(MODEL_ZOO_CONFIG))
# Fine-tune from the COCO-pretrained checkpoint rather than from scratch.
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(MODEL_ZOO_CONFIG)

# Train and validation splits, by the names they were registered under above.
cfg.DATASETS.TRAIN = ("boardetect_train",)
cfg.DATASETS.TEST = ("boardetect_val",)

cfg.DATALOADER.NUM_WORKERS = 4  # number of data-loading workers
cfg.SOLVER.IMS_PER_BATCH = 4  # images per batch across all machines
cfg.SOLVER.BASE_LR = 0.0125
cfg.SOLVER.MAX_ITER = 1500  # total training iterations
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 256  # RoIs sampled per image for ROI-head training
# Take the class count from the registered metadata (HINDI, ENGLISH, OTHER)
# so it cannot drift from thing_classes.
cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(MetadataCatalog.get("boardetect_train").thing_classes)
cfg.TEST.EVAL_PERIOD = 500  # evaluate on the validation split every 500 iterations

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = CocoTrainer(cfg)
# resume=False: load cfg.MODEL.WEIGHTS and ignore any checkpoint already in OUTPUT_DIR.
trainer.resume_or_load(resume=False)
trainer.train()
- 取得連結
- X
- 以電子郵件傳送
- 其他應用程式
留言
張貼留言