Create Your First PyTorch Model

For those of you who don't know, my family owns a regenerative farm in Michigan. Regenerative farming uses a combination of techniques to improve soil health, and healthier soil also means healthier livestock. In a sense, we're allowing the environment to operate the way it would in nature. One of the farm's main products is rotationally grazed, grass-fed beef cattle. Over the years the herd has grown from only about 10 animals to around 50 today.

We've always named each cow and given it a little backstory. As the herd gets larger, it's become harder to remember all their names. So my idea is to create a "cow name predictor" that works from an image I upload from my phone. For example, I snap a photo, upload it to a web application, click a button to classify, and voila! I get back the name of the cow along with a confidence value.

Probably not the most useful application out there. It's not like Cupcake (yes, that's actually the name of one of our cows) is going to care if I get her name wrong.

If you would like to learn more about my family's farm, feel free to check out their website Sugar Creek Farms.

Tools we will use

I'll be using the following tools to get this started.

  • Jupyter Notebook
  • PyTorch
  • (& fastbook)
  • voila
  • Azure Search Key - used to request images from Bing
  • That's it!

If you haven't used fastai yet, I would start immediately. It's a Python package that sits on top of PyTorch, much as Keras does for TensorFlow. There are many reasons to look into it, but I'll keep it short and list only a couple of important points. First, fastai is more of a framework and approach for getting a PyTorch model set up: it removes most of the boilerplate code needed to get a PyTorch model up and running. Second, fastai has a lot of resources for getting started in AI and machine learning.


While my task may be worthwhile, I may have underestimated how hard this would be. A few considerations came to mind. How am I going to get images of all the individual cows? Will I even have enough to train a model on?

After looking through my photos folder, I've determined this is going to be hard. I don't even have pictures of all the cows. You don't need as much data as you may think to train a model; it's surprising how much you can get out of a small dataset. That being said, I don't even have 100 photos. So my goal of predicting a cow's name isn't going to work today. I'll work on collecting that dataset and try again in the future :)

Instead, I'll change the goal a little, keeping it in the same vein, and build a model that predicts a cow's breed. I should have no issues getting enough data for that. Here is what I'm thinking. I can create a Python list of breeds. I Googled common cattle breeds and got 'black angus', 'charolais', 'hereford', 'simmental', 'red angus', 'texas longhorn', 'holstein', 'limousin', 'highlands'. I can then loop through the breeds and use Bing's Image Search API to request images of each.


Open up Jupyter Notebook and start a new notebook. Then install fastai, fastbook, and PyTorch whichever way is easiest for you. I usually use pip.

from fastbook import *
from import *

Data Gathering

For this next bit, you will have to sign up for a free Microsoft Azure account. This will be the hardest part if you are following along and replicating this project. Do a quick Google search on how to sign up and you will find a lot of resources. Once you have a free account, navigate to the Azure Search dashboard and grab your AZURE_SEARCH_KEY from the settings. You will need it to use the Azure/Bing image search API. A quick note on this service: it's free, but it limits you to 150 results per request, and you can only make so many requests per second.

key = os.environ.get('AZURE_SEARCH_KEY', 'xxxxxxxxxxxxxxxxx')

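The line above reads the AZURE_SEARCH_KEY environment variable and falls back to the placeholder string if the variable isn't set. So either set the variable before starting Jupyter, or replace the x's with your key.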
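Because os.environ.get quietly returns the fallback when the variable is missing, a forgotten key only shows up later as a failed API call. The sketch below fails early instead. load_search_key and the placeholder check are my own illustration, not part of fastbook; a plain dict stands in for the real environment so it can run anywhere.

```python
import os

PLACEHOLDER = "xxxxxxxxxxxxxxxxx"

def load_search_key(env=os.environ):
    """Hypothetical helper: read AZURE_SEARCH_KEY and fail loudly if it's missing.

    os.environ.get returns its second argument when the variable isn't set,
    so an unset key would otherwise silently become the placeholder string.
    """
    key = env.get("AZURE_SEARCH_KEY", PLACEHOLDER)
    if key == PLACEHOLDER:
        raise RuntimeError("AZURE_SEARCH_KEY is not set")
    return key

# Usage with a plain dict standing in for the real environment
print(load_search_key({"AZURE_SEARCH_KEY": "abc123"}))  # abc123
```

In a real notebook you'd call load_search_key() with no arguments after setting the variable in your shell (for example, export AZURE_SEARCH_KEY=... on macOS/Linux) before launching Jupyter.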

Once you've set key, you can use a handy function from fastbook, search_images_bing. It's a helper that makes a request to the Azure image search API. Evaluating its name in a cell shows its signature:

<function fastbook.search_images_bing(key, term, min_sz=128)>

TIP: You'll notice we didn't import this function explicitly. We imported everything from fastbook using a wildcard import (*). This is usually a bad idea, but in this case it's fine. The reason it's usually a bad idea is that import * pulls every public name from a module into your namespace, including the modules and names that the module itself imported. That clutters your namespace, can silently shadow names you've defined, and makes it hard to tell where something came from. Another little tip: if you're not sure where something comes from or how to use it, click on the variable or function and press Shift + Tab. This brings up a small information box about it.
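To see the namespace clutter for yourself, here's a small experiment with a throwaway in-memory module (demo_mod is a made-up name). It also shows the standard remedy: a module can define __all__ to limit what a star import brings in.

```python
import sys
import types

# Build a throwaway in-memory module (hypothetical name "demo_mod") that
# imports two standard-library modules and defines one function.
demo = types.ModuleType("demo_mod")
exec("import json\nimport os\n\ndef helper():\n    return 'hi'\n", demo.__dict__)
sys.modules["demo_mod"] = demo

# Star-import it into a fresh namespace and list the public names that arrived.
ns = {}
exec("from demo_mod import *", ns)
print(sorted(k for k in ns if not k.startswith("_")))  # ['helper', 'json', 'os']

# Defining __all__ restricts a star import to the listed names.
demo.__all__ = ["helper"]
ns_limited = {}
exec("from demo_mod import *", ns_limited)
print(sorted(k for k in ns_limited if not k.startswith("_")))  # ['helper']
```

Notice that json and os came along for the ride in the first import even though we only cared about helper.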

results = search_images_bing(key, 'angus cattle')
ims = results.attrgot('content_url')

This gives us a list of 150 image URLs for angus cattle from Bing. Nothing has been downloaded yet; we'll grab one below to take a look. There are a ton of other ways to get data out there, and this may not be the best one for your model, so do a quick Google search and see what's available.
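results is a fastcore L list, and attrgot('content_url') collects the content_url field from every search result. The plain-Python sketch below mimics that behavior so you can see what ims holds. It's an approximation of fastcore's method, not a copy of it, and the sample results are made up.

```python
def attrgot(items, key, default=None):
    """Plain-Python stand-in for fastcore's L.attrgot: pull `key` out of every item,
    using dict lookup for dicts and attribute lookup for everything else."""
    return [
        item.get(key, default) if isinstance(item, dict) else getattr(item, key, default)
        for item in items
    ]

# Hypothetical search results; real Bing responses carry many more fields.
fake_results = [
    {"name": "Angus bull", "content_url": "https://example.com/angus1.jpg"},
    {"name": "Angus cow", "content_url": "https://example.com/angus2.jpg"},
    {"name": "No URL here"},
]
print(attrgot(fake_results, "content_url"))
# ['https://example.com/angus1.jpg', 'https://example.com/angus2.jpg', None]
```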

Below is an example of one of the images we grabbed so you can see it. If you look through your working directory, you'll also see there is now an images folder.

dest = 'images/cattle.jpg'
download_url(ims[0], dest)
im =

Okay, so now let's see if we can create a list of breeds and grab images for each! I'll also create a path variable.

cattle_types = ('black angus', 'charolais', 'hereford', 'simmental', 'red angus',
                'texas longhorn', 'holstein', 'limousin', 'highlands')
path = Path('cattle')
if not path.exists():
    path.mkdir()  # create cattle/ first so each breed sub-folder has a parent
    for o in cattle_types:
        dest = (path/o)
        dest.mkdir(exist_ok=True)  # one sub-folder per breed, e.g. cattle/hereford
        results = search_images_bing(key, f'{o} cattle')
        download_images(dest, urls=results.attrgot('content_url'))

The code above checks whether a folder called cattle exists. If it doesn't, the code creates it along with a sub-folder for each breed listed above, then downloads that breed's search results into it. Depending on your internet speed, this should take a few minutes to run. At the end you'll have up to 150 images per breed; a few downloads usually fail, which is why the count below is 1,322 rather than 9 × 150 = 1,350.
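Here's the same control flow as a standalone sketch, with the network calls swapped for a hypothetical fetch function so you can run it anywhere. It also shows why rerunning the cell does nothing once cattle/ exists: if you want fresh downloads, delete the folder first.

```python
import tempfile
from pathlib import Path

def download_all(path, breeds, fetch):
    """Sketch of the notebook loop: one sub-folder per breed, one search per breed.

    `fetch(query, dest)` is a hypothetical stand-in for the search-and-download
    step (search_images_bing + download_images). Returns False and does nothing
    if `path` already exists, mirroring the `if not path.exists()` guard.
    """
    path = Path(path)
    if path.exists():
        return False
    path.mkdir(parents=True)
    for breed in breeds:
        dest = path / breed
        dest.mkdir(exist_ok=True)
        fetch(f"{breed} cattle", dest)
    return True

# Usage: record the queries instead of hitting the network.
queries = []
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp) / "cattle"
    first = download_all(root, ["hereford", "red angus"], lambda q, d: queries.append(q))
    second = download_all(root, ["hereford", "red angus"], lambda q, d: queries.append(q))
    folders = sorted(p.name for p in root.iterdir())

print(first, second)  # True False
print(queries)        # ['hereford cattle', 'red angus cattle']
print(folders)        # ['hereford', 'red angus']
```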

fns = get_image_files(path)
(#1322) [Path('cattle/black angus/00000042.jpg'),Path('cattle/black angus/00000092.jpg'),Path('cattle/black angus/00000130.jpg'),Path('cattle/black angus/00000147.jpg'),Path('cattle/black angus/00000122.jpg'),Path('cattle/black angus/00000148.jpg'),Path('cattle/black angus/00000135.jpg'),Path('cattle/black angus/00000026.jpg'),Path('cattle/black angus/00000066.jpg'),Path('cattle/black angus/00000025.jpg')...]

Now that the photos have been downloaded and saved in root_directory/cattle/[breed_name], let's check for bad images. You will almost always get some files that are corrupt, not encoded correctly, or not images at all. It's best to filter those out and throw them away, assuming you still have enough data.

Run the following and it will show which images are corrupt. In our case, 56 images failed. Take a look and see if you can figure out why. To remove them, we call the unlink method on each Path, which deletes the file from disk so these images never get sent through our model.

failed = verify_images(fns)
(#56) [Path('cattle/black angus/00000147.jpg'),Path('cattle/black angus/00000011.jpg'),Path('cattle/black angus/00000044.jpg'),Path('cattle/black angus/00000145.JPG'),Path('cattle/black angus/00000056.JPG'),Path('cattle/black angus/00000139.jpg'),Path('cattle/black angus/00000032.jpg'),Path('cattle/black angus/00000043.jpg'),Path('cattle/black angus/00000144.jpg'),Path('cattle/black angus/00000116.jpg')...]
failed.map(Path.unlink);  # Path.unlink deletes each failed file from disk so it won't be used in our model
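Curious what a check like this involves? As I understand it, fastai's verify_images tries to open each file as an image. The sketch below is much cruder: it only looks at the first bytes of each file for a JPEG or PNG signature, then deletes failures with Path.unlink. The file names are made up, and the HTML "image" is just one plausible way a download can end up as a failed file.

```python
import tempfile
from pathlib import Path

# File signatures ("magic bytes") for two common formats.
SIGNATURES = (b"\xff\xd8\xff", b"\x89PNG\r\n\x1a\n")  # JPEG, PNG

def looks_like_image(path):
    """Crude check: does the file start with a JPEG or PNG signature?"""
    with open(path, "rb") as f:
        head = f.read(8)
    return any(head.startswith(sig) for sig in SIGNATURES)

def remove_bad_images(files):
    """Delete every file that fails the check and return the list of failures."""
    failed = [f for f in files if not looks_like_image(f)]
    for f in failed:
        f.unlink()  # Path.unlink deletes the file from disk
    return failed

with tempfile.TemporaryDirectory() as tmp:
    good = Path(tmp) / "cow.jpg"
    good.write_bytes(b"\xff\xd8\xff\xe0" + b"\x00" * 16)
    bad = Path(tmp) / "not_a_cow.jpg"
    bad.write_bytes(b"<html>Access denied</html>")
    failed = remove_bad_images([good, bad])
    remaining = sorted(p.name for p in Path(tmp).iterdir())

print([p.name for p in failed])  # ['not_a_cow.jpg']
print(remaining)                 # ['cow.jpg']
```

A header check like this would miss a JPEG that starts correctly but is truncated halfway through, which is why actually decoding the file is the safer test.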

Preparing Training and Testing Datasets

cattle = DataBlock(
    blocks=(ImageBlock, CategoryBlock), 
    get_items=get_image_files,
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_y=parent_label,
    item_tfms=Resize(128))

The code above uses another useful piece of fastai. DataBlock is a template that describes how to turn our raw image files into training and testing (validation) datasets; nothing is actually loaded until we create DataLoaders from it.

The first argument, blocks, specifies the types of our independent variable (ImageBlock) and dependent variable (CategoryBlock). These are often called features and labels. We train our model on the features, while the labels are the correct answers. For example, if you're trying to predict someone's salary, your features might be age, education, location, industry, etc., and the label is the salary itself, such as $65,000/yr.

Next, we pass get_image_files to the get_items argument. get_image_files is a function that returns a list of the paths of all image files under a folder, which tells the DataBlock where our images are located.

splitter randomly splits the dataset into a training set and a testing (validation) set. It's extremely important to do this as early in the process as you can, to make sure you are not accidentally training on testing data, aka cheating. valid_pct is the fraction of images to reserve for testing. In this case we reserve 20% of the data to test on, and the remaining 80% is used to train our model. seed sets the starting point of the random number generator, so we get the same 'random' split every time we run the notebook.
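Here's the idea behind a seeded random split in plain Python: shuffle the indices with a seeded generator and hold out the first 20%. This is my own illustration; fastai's RandomSplitter draws its permutation from PyTorch's generator, so it won't pick the same images as this sketch, but the reproducibility argument is the same.

```python
import random

def random_split(items, valid_pct=0.2, seed=None):
    """Shuffle indices and hold out the first valid_pct of them for validation."""
    idxs = list(range(len(items)))
    random.Random(seed).shuffle(idxs)  # a seeded generator gives a repeatable order
    cut = int(valid_pct * len(items))
    valid = [items[i] for i in idxs[:cut]]
    train = [items[i] for i in idxs[cut:]]
    return train, valid

files = [f"img_{i:03d}.jpg" for i in range(10)]
train_a, valid_a = random_split(files, valid_pct=0.2, seed=42)
train_b, valid_b = random_split(files, valid_pct=0.2, seed=42)

print(len(train_a), len(valid_a))  # 8 2
print(valid_a == valid_b)          # True: same seed, same split
```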

get_y sets what our labels should be. fastai provides a function called parent_label, which grabs the name of the parent directory the image is in and uses it as the label, or dependent variable.
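This only works because of how we laid out the folders: one folder per breed. The function below re-creates the idea with pathlib. The name folder_label and the second path are made up for illustration; it's a sketch of the behavior described above, not fastai's source.

```python
from pathlib import Path

def folder_label(path):
    """Label an image with the name of the folder it sits in."""
    return Path(path).parent.name

print(folder_label("cattle/black angus/00000042.jpg"))     # black angus
print(folder_label("cattle/texas longhorn/00000007.jpg"))  # texas longhorn
```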

Lastly, item_tfms ("item transforms") resizes every image to 128 by 128 pixels. The images need to have exactly the same dimensions so that a batch of them can be stacked into a single tensor and pushed through the model's matrix multiplications together.

dls = cattle.dataloaders(path)

Run the following to display a batch from the validation set and check that the images and labels look right.

dls.valid.show_batch(max_n=10, nrows=2)

Data Augmentation

To make all the photos the same size of 128x128 pixels, we have several options. We can squish or stretch them, which usually leads to photos of cows with very odd proportions not seen in nature. We can pad them: think widescreen movie view, with black bars on the top and bottom of the screen. This option doesn't distort the photo at all, but it does create a lot of empty space.
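To put numbers on these options, here's the arithmetic for a hypothetical 640×480 photo going to 128×128, including cropping for comparison. resize_plan is my own back-of-the-envelope helper, not fastai code.

```python
def resize_plan(w, h, size=128):
    """Work out what squish, pad, and crop each do to a w x h photo
    when the target is a size x size square."""
    # Squish: scale each axis independently (changes the aspect ratio).
    squish = (size / w, size / h)
    # Pad: scale so the longer side fits, then fill the leftover with bars.
    s = size / max(w, h)
    fit_w, fit_h = round(w * s), round(h * s)
    pad = (size - fit_w, size - fit_h)
    # Crop: scale so the shorter side fits, then cut off the overflow.
    s = size / min(w, h)
    fill_w, fill_h = round(w * s), round(h * s)
    crop = (fill_w - size, fill_h - size)
    return {"squish_scale": squish, "pad_pixels": pad, "crop_pixels": crop}

plan = resize_plan(640, 480)
print(plan)
# squish scales width by 0.2 but height by about 0.267, so cows look stretched;
# pad_pixels (0, 32): 16-pixel bars top and bottom;
# crop_pixels (43, 0): about 43 pixels of width are cut off.
```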

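The common practice is to select a different random crop of each photo every time it's used. This lets us use each image several times, effectively increasing the size of our dataset. In the code below, RandomResizedCrop(224, min_scale=0.5) picks a crop covering at least half of each image and resizes it to 224 by 224 pixels. Look at the code and the example batch it shows to see what I mean.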

This process is referred to as data augmentation. There are a lot of common techniques you can use here, including rotation, warping, flipping, and contrast and brightness changes.

cattle = cattle.new(item_tfms=RandomResizedCrop(224, min_scale=0.5))  # same DataBlock, new item transform
dls = cattle.dataloaders(path)
dls.train.show_batch(max_n=8, nrows=2, unique=True)
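To make RandomResizedCrop less of a black box, here's a plain-Python sketch of how a random resized crop box can be chosen, plus two of the other augmentations mentioned above (flipping and brightness) applied to a tiny grid of pixel values. The sampling follows the common recipe popularized by torchvision: a random area between min_scale and 100%, a random aspect ratio between 3/4 and 4/3, and a centered square as a fallback. Treat those details as my assumptions rather than fastai's exact code.

```python
import math
import random

def random_resized_crop_box(w, h, min_scale=0.5, ratio=(3/4, 4/3), rng=None, tries=10):
    """Pick a random crop covering between min_scale and 100% of the image area,
    with an aspect ratio inside `ratio`. Returns (left, top, crop_w, crop_h);
    the crop would then be resized to the final size (224 in the cell above)."""
    rng = rng or random.Random()
    area = w * h
    for _ in range(tries):
        target = rng.uniform(min_scale, 1.0) * area
        aspect = math.exp(rng.uniform(math.log(ratio[0]), math.log(ratio[1])))
        cw = round(math.sqrt(target * aspect))
        ch = round(math.sqrt(target / aspect))
        if 0 < cw <= w and 0 < ch <= h:
            return rng.randint(0, w - cw), rng.randint(0, h - ch), cw, ch
    side = min(w, h)  # fallback: centered square crop
    return (w - side) // 2, (h - side) // 2, side, side

def hflip(img):
    """Mirror a 2-D grid of pixel values left to right."""
    return [row[::-1] for row in img]

def adjust_brightness(img, factor):
    """Scale every pixel by `factor`, clamping to the 0-255 range."""
    return [[min(255, max(0, round(p * factor))) for p in row] for row in img]

rng = random.Random(0)
print(random_resized_crop_box(640, 480, min_scale=0.5, rng=rng))
tiny = [[10, 20, 30], [40, 50, 60]]
print(hflip(tiny))                   # [[30, 20, 10], [60, 50, 40]]
print(adjust_brightness(tiny, 1.5))  # [[15, 30, 45], [60, 75, 90]]
```

Because a fresh box is drawn every time an image is loaded, the model rarely sees exactly the same pixels twice, which is where the "bigger dataset" effect comes from.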

Training Your Model

Now that we have cleaned and preprocessed (split and augmented) our data, we are ready to train our model!

learn = cnn_learner(dls, resnet18, metrics=error_rate)  # newer fastai releases name this vision_learner
Downloading: "" to /home/user/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth
learn.fine_tune(4)

epoch train_loss valid_loss error_rate time
0 2.620505 0.873106 0.284585 00:13
epoch train_loss valid_loss error_rate time
0 1.087951 0.742601 0.233202 00:12
1 0.934089 0.628281 0.189723 00:13
2 0.764373 0.587155 0.189723 00:13
3 0.651530 0.581575 0.185771 00:14

At first glance we didn't do that great. The final error rate is 0.185771, meaning roughly 19% of the validation images were misclassified. Let's take a look at the confusion matrix to visualize the results and see where we can improve our model.

interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()

The rows represent the actual breed of each image, and the columns represent the breed the model predicted. This means the diagonal holds the number of images that were classified correctly, while everything off the diagonal was predicted wrong.
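The bookkeeping behind a confusion matrix is simple enough to sketch by hand. The labels below are made up; the error_rate helper also shows how the metric in the training table relates to the matrix: everything off the diagonal divided by the total.

```python
from collections import Counter

def confusion_matrix(actual, predicted, classes):
    """Rows are actual classes, columns are predicted classes."""
    counts = Counter(zip(actual, predicted))
    return [[counts[(a, p)] for p in classes] for a in classes]

def error_rate(matrix):
    """Share of items off the diagonal, i.e. misclassified."""
    total = sum(sum(row) for row in matrix)
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    return 1 - correct / total

classes = ["hereford", "red angus", "simmental"]
actual = ["hereford", "hereford", "red angus", "red angus", "simmental", "simmental"]
predicted = ["hereford", "simmental", "red angus", "red angus", "simmental", "hereford"]
m = confusion_matrix(actual, predicted, classes)
for name, row in zip(classes, m):
    print(f"{name:>10} {row}")
print(error_rate(m))  # 2 of 6 off the diagonal, about 0.333
```

In this toy example, hereford and simmental are confused with each other, which is exactly the kind of pattern the real matrix helps you spot.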

It can be helpful to look at where the errors are coming from. When I did this the first time, I had only the breed 'angus' rather than 'black angus' and 'red angus'. Splitting these out improved the results. We could do the exact same thing for 'hereford' and 'simmental', since both breeds often come in both red and black. I would be willing to bet that splitting these two up would improve our model. Why don't you give it a try and see?

Because this dataset was gathered from online images, some images could be labeled wrong. I also see some hand-drawn images and cartoons that are causing issues. It's always best to look at the images the model got wrong or was least confident about.

interp.plot_top_losses(5, nrows=1)
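plot_top_losses ranks images by loss. For a single-label classifier like ours, the loss is cross-entropy: minus the log of the probability the model gave to the correct breed. That means a confidently wrong prediction scores highest. Here's a sketch of that ranking with made-up probabilities and file names:

```python
import math

def top_losses(probs_for_true_class, k=3):
    """Rank items by cross-entropy loss, -log(p), where p is the probability the
    model gave to the correct label. Highest loss = most confidently wrong."""
    losses = {name: -math.log(p) for name, p in probs_for_true_class.items()}
    return sorted(losses.items(), key=lambda kv: kv[1], reverse=True)[:k]

# Hypothetical probabilities the model assigned to each image's true breed.
probs = {
    "cartoon_cow.jpg": 0.02,
    "hereford_01.jpg": 0.95,
    "simmental_07.jpg": 0.30,
    "angus_12.jpg": 0.60,
}
for name, loss in top_losses(probs, k=2):
    print(f"{name}: {loss:.2f}")
# cartoon_cow.jpg: 3.91
# simmental_07.jpg: 1.20
```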