
Ship Detection - Part 3 (Model Training)

Training YOLOv5 for ship detection

Previously, we looked at handling outliers, which would disrupt the objects if they were not split. In this part, we will train YOLOv5 to detect ships, using the Ultralytics YOLOv5 implementation.

Creating dataset.yaml

To train YOLOv5, we need to create a dataset.yaml file and pass it to the training script.

This file should contain the path to the dataset root, followed by the train, val, and test splits, given relative to that root. The test split is optional. Finally, it should map each class index to its name.

You can refer to this link for more information.

The following works in our case:

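# YOLOv5 resolves a relative "path" against the yolov5/ repository folder,
# so "../ship-detection" points to a ship-detection/ folder next to yolov5/
dataset_path = "../ship-detection"
train_path = "train"   # relative to dataset_path
val_path = "val"       # relative to dataset_path
test_path = ""         # optional, left empty here

classes = {0: "ship"}
# One "  index: name" line per class, indented under the names: key
names = "\n".join(f"  {k}: {v}" for k, v in classes.items())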
yaml_text = f"""
path: {dataset_path}
train: {train_path}
val: {val_path}
test: {test_path}

names:
{names}
"""
print(yaml_text)
path: ../ship-detection
train: train
val: val
test: 

names:
  0: ship
Finally, we write the text to dataset.yaml:

with open("dataset.yaml", "w") as file:
    file.write(yaml_text)
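Before training, it can be worth checking that every image has a label file. YOLOv5 finds the label for an image by swapping the images/ folder for labels/ and the extension for .txt, and an image without a label file is treated as a background image (no ships) rather than raising an error. The sketch below assumes a ship-detection/<split>/images and ship-detection/<split>/labels layout; `find_missing_labels` and the list of image extensions are hypothetical, not part of YOLOv5.

```python
from pathlib import Path

# Image extensions to check; an assumption, extend as needed
IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp"}


def find_missing_labels(split_dir):
    """Hypothetical helper: return images in split_dir/images that have
    no matching split_dir/labels/<stem>.txt file."""
    split_dir = Path(split_dir)
    labels_dir = split_dir / "labels"
    missing = []
    for img in sorted((split_dir / "images").iterdir()):
        if img.suffix.lower() not in IMAGE_EXTS:
            continue  # skip non-image files
        if not (labels_dir / f"{img.stem}.txt").is_file():
            missing.append(img)
    return missing


# Assumed layout: ship-detection/train/images, ship-detection/train/labels, ...
for split in ("train", "val"):
    split_dir = Path("ship-detection") / split
    if (split_dir / "images").is_dir():
        print(split, "images without labels:", len(find_missing_labels(split_dir)))
```

A non-zero count is not necessarily wrong, since some tiles may genuinely contain no ships, but it is worth confirming before a long training run.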

Training the model

We will now clone the YOLOv5 repository.

!git clone https://github.com/ultralytics/yolov5.git

Next, we install its dependencies:

!pip install -r yolov5/requirements.txt

Now we can start training.

!python yolov5/train.py --imgsz 3008 --batch 1 --epochs 50 \
    --data dataset.yaml --weights yolov5l.pt --workers 4

Here, --imgsz is the training image size in pixels, --batch is the batch size, and --epochs is the number of training epochs. --data points to the dataset.yaml file we created earlier, --weights selects the pretrained weights to start from (yolov5l.pt is the large YOLOv5 model), and --workers sets the number of dataloader workers.
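YOLOv5 requires --imgsz to be a multiple of the model's maximum stride, which is 32 for the standard models such as yolov5l.pt. If it is not, train.py rounds the size up to the next multiple and prints a warning. 3008 = 94 × 32, so our value passes unchanged. A minimal sketch of that arithmetic, where `round_imgsz` is a hypothetical helper rather than YOLOv5's own function:

```python
import math


def round_imgsz(imgsz, stride=32):
    """Hypothetical helper: round imgsz up to the nearest multiple of stride,
    mirroring the ceil-based rounding YOLOv5 applies to --imgsz."""
    return math.ceil(imgsz / stride) * stride


print(round_imgsz(3008))  # already a multiple of 32, stays 3008
print(round_imgsz(3000))  # rounded up to 3008
```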

The training log is too long to reproduce here, but at the end it reports where the results were saved:

Results saved to yolov5/runs/train/exp2
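With the default settings, each call to train.py creates a new numbered run directory (exp, exp2, exp3, ...), so the run number depends on how many times you have trained. Rather than hardcoding it, a small helper can pick the run whose best.pt was written most recently. This is a sketch; `latest_run` is a hypothetical name, and the default runs directory assumes the repository was cloned into yolov5/ as above.

```python
from pathlib import Path


def latest_run(runs_dir="yolov5/runs/train"):
    """Hypothetical helper: return the run directory whose weights/best.pt
    was modified most recently, or None if no run has a best.pt."""
    runs = [p for p in Path(runs_dir).glob("exp*")
            if (p / "weights" / "best.pt").is_file()]
    if not runs:
        return None
    # Compare the checkpoint files themselves, not the directories
    return max(runs, key=lambda p: (p / "weights" / "best.pt").stat().st_mtime)
```

The returned path, joined with weights/best.pt, can then be passed as the path argument when loading the model for inference.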

Inference

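Once training has finished, we load the best checkpoint through torch.hub and run it on one of the validation images at the same resolution used for training.

%matplotlib inline
import torch
import matplotlib.pyplot as plt

exp_id = "2"        # matches "Results saved to yolov5/runs/train/exp2"
resolution = 3008   # same --imgsz used for training

# Load our custom-trained weights (best.pt is the checkpoint with the best validation score)
model = torch.hub.load("ultralytics/yolov5", "custom",
                       path=f"yolov5/runs/train/exp{exp_id}/weights/best.pt")

# Run inference on a validation image at the training resolution
model_result = model("ship-detection/val/images/6.png", size=resolution)

# render() returns the images with the predicted boxes drawn on them
plt.imshow(model_result.render()[0])
plt.show()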

