Ship Detection - Part 3 (Model Training)
Training YOLOv5 for Ship Detection
Previously, we looked at handling outliers. These outliers would disrupt the objects if they were not split. In this section, we will see how to train YOLOv5 to detect ships, using the Ultralytics YOLOv5 implementation.
Creating dataset.yaml
To train YOLOv5, we need to create a dataset.yaml file and pass it to the training script.
This file should contain the path to the dataset root (a relative path is resolved from the yolov5 directory, which is why ../ship-detection is used below). It should also point to the train, val and test splits, relative to that root; the test split is optional. Finally, it should map each class label (an integer index) to its class name.
You can refer to this link for more information.
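The train and val entries point at folders of images; YOLOv5 does not list label files in the YAML but derives each label path from the image path, replacing the last images directory with labels and the file extension with .txt. The sketch below reproduces that convention so the folder layout can be checked before training. It is modeled on YOLOv5's internal img2label_paths helper; the function name img2label_path is our own, and it assumes each split contains images/ and labels/ subfolders, as the ship-detection/val/images path used for inference suggests.

```python
import os


def img2label_path(img_path: str) -> str:
    """Map an image path to its YOLO label path (sketch of YOLOv5's convention).

    Hypothetical helper: the last '/images/' component is swapped for
    '/labels/' and the image extension is replaced with '.txt'.
    """
    sa = f"{os.sep}images{os.sep}"
    sb = f"{os.sep}labels{os.sep}"
    # Replace only the last occurrence of /images/ in the path
    swapped = sb.join(img_path.rsplit(sa, 1))
    # Drop the image extension and use .txt instead
    return swapped.rsplit(".", 1)[0] + ".txt"


if __name__ == "__main__":
    image = os.path.join("ship-detection", "val", "images", "6.png")
    print(img2label_path(image))
```

For example, ship-detection/val/images/6.png maps to ship-detection/val/labels/6.txt, so that file must exist for the image's boxes to be used.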
The following works in our case:
dataset_path = "../ship-detection"
train_path = "train"
val_path = "val"
test_path = ""
classes = {0:"ship"}
names = ""
# Each class goes on its own line, indented under "names:"
for k, v in classes.items():
    names += f"  {k}: {v}\n"
yaml_text = f"""
path: {dataset_path}
train: {train_path}
val: {val_path}
test: {test_path}
names:
{names}
"""
print(yaml_text)
path: ../ship-detection
train: train
val: val
test:
names:
  0: ship
# Save the config in the working directory; it is passed to train.py with --data
with open("dataset.yaml", "w") as file:
    file.write(yaml_text)
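Building the YAML by string formatting means getting the indentation right by hand. As an alternative, here is a sketch that produces the same configuration with PyYAML's safe_dump (PyYAML is listed in YOLOv5's requirements.txt, but treat its availability in your environment as an assumption). Adding more classes then only means adding entries to the names dictionary.

```python
import yaml  # PyYAML

# Same values as the hand-written version; test is empty because it is optional
config = {
    "path": "../ship-detection",
    "train": "train",
    "val": "val",
    "test": "",
    "names": {0: "ship"},
}

# safe_dump takes care of quoting and indentation; sort_keys=False keeps this order
yaml_text = yaml.safe_dump(config, sort_keys=False)

with open("dataset.yaml", "w") as file:
    file.write(yaml_text)

print(yaml_text)
```

One difference: the empty test value is written as '' rather than left blank. Both load as an empty or missing value, and YOLOv5 skips split entries that are empty.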
Training the model
We will now clone the YOLOv5 repository.
!git clone https://github.com/ultralytics/yolov5.git
Now we need to install all the dependencies.
!pip install -r yolov5/requirements.txt
Now we can start training.
!python yolov5/train.py --imgsz 3008 --batch 1 --epochs 50 \
--data dataset.yaml --weights yolov5l.pt --workers 4
Here:
--imgsz is the training image size in pixels.
--batch is the batch size.
--epochs is the number of training epochs.
--data refers to the dataset.yaml file we created earlier.
--weights selects the pretrained YOLOv5 checkpoint to start from; yolov5l.pt is the large model and is downloaded automatically if it is not already present.
--workers sets the maximum number of dataloader workers used during training.
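For the standard YOLOv5 models such as yolov5l, the largest downsampling stride is 32, so --imgsz must be a multiple of 32; if it is not, train.py rounds it up and prints a warning. The value 3008 (94 × 32) already satisfies this. The sketch below reproduces that rounding; valid_imgsz is a hypothetical name, not a YOLOv5 function, and the stride of 32 is an assumption that holds for the P5 models.

```python
import math


def valid_imgsz(imgsz: int, stride: int = 32) -> int:
    """Round imgsz up to the nearest multiple of the model stride.

    Hypothetical helper mirroring the rounding YOLOv5 applies to --imgsz;
    stride=32 is assumed (standard P5 models such as yolov5l).
    """
    return math.ceil(imgsz / stride) * stride


if __name__ == "__main__":
    for size in (640, 3000, 3008):
        print(size, "->", valid_imgsz(size))
```

Rounding up rather than down keeps every pixel of the requested size, at the cost of a slightly larger input.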
The training log is too long to reproduce here, but at the end it reports where the results were saved, for example:
Results saved to yolov5/runs/train/exp2
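YOLOv5 gives each training run a new folder (exp, exp2, exp3, ...), so the number in this path depends on how many runs preceded it. Instead of hard-coding the run number, a small helper can pick the most recently modified run folder. This is a sketch under the assumption that runs live in the default yolov5/runs/train directory; latest_run is a hypothetical helper, not part of YOLOv5.

```python
from pathlib import Path
from typing import Optional


def latest_run(runs_dir: str = "yolov5/runs/train") -> Optional[Path]:
    """Return the most recently modified exp* folder, or None if there is none.

    Hypothetical helper: sorting by modification time picks the run that was
    written last, regardless of its exp number.
    """
    runs = [p for p in Path(runs_dir).glob("exp*") if p.is_dir()]
    if not runs:
        return None
    return max(runs, key=lambda p: p.stat().st_mtime)


if __name__ == "__main__":
    run = latest_run()
    if run is not None:
        # The best checkpoint of that run, as used for inference
        print(run / "weights" / "best.pt")
```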
Inference
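Once training has finished, we can load the best checkpoint and run it on a validation image.
%matplotlib inline
import torch
import matplotlib.pyplot as plt

# Run folder reported at the end of training (yolov5/runs/train/exp2)
exp_id = "2"
# Inference size; matches the --imgsz used for training
resolution = 3008

# Load the best checkpoint from that run as a custom YOLOv5 model
model = torch.hub.load('ultralytics/yolov5', 'custom', path=f"yolov5/runs/train/exp{exp_id}/weights/best.pt")
# Run detection on one validation image
model_result = model("ship-detection/val/images/6.png", size=resolution)

# render() draws the predicted boxes and returns the annotated images as arrays
plt.imshow(model_result.render()[0])
plt.show()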
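Besides drawing boxes, the results object exposes the raw detections: model_result.xyxy[0] holds one row per box in the form [x1, y1, x2, y2, confidence, class]. The sketch below counts detections per class above a confidence threshold from those rows. The summarize helper, the CLASSES mapping and the 0.5 threshold are illustrative assumptions; note that the hub model already drops boxes below its own threshold (model.conf, 0.25 by default).

```python
from typing import Dict, Sequence

# Hypothetical class map, matching the names in dataset.yaml
CLASSES = {0: "ship"}


def summarize(detections: Sequence[Sequence[float]], min_conf: float = 0.5) -> Dict[str, int]:
    """Count detections per class name at or above a confidence threshold.

    Each row is assumed to be [x1, y1, x2, y2, confidence, class], the layout
    of YOLOv5's results.xyxy[0] after converting it with .tolist().
    """
    counts: Dict[str, int] = {}
    for _x1, _y1, _x2, _y2, conf, cls in detections:
        if conf < min_conf:
            continue
        name = CLASSES.get(int(cls), str(int(cls)))
        counts[name] = counts.get(name, 0) + 1
    return counts


if __name__ == "__main__":
    # Made-up rows for illustration only, not real model output
    rows = [
        [10.0, 20.0, 60.0, 80.0, 0.91, 0.0],
        [300.0, 310.0, 340.0, 350.0, 0.42, 0.0],
    ]
    print(summarize(rows))
```

With the trained model, the call would look like summarize(model_result.xyxy[0].tolist()).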