In [1]:
from transformers import VisionEncoderDecoderModel, TrOCRProcessor
from PIL import Image
from tqdm import tqdm

import torch
import os
import glob as glob
import matplotlib.pyplot as plt
import warnings

In [2]:
warnings.filterwarnings('ignore')

In [3]:
# Use the first GPU when CUDA is usable, otherwise fall back to the CPU.
# `is_available` must be *called*: the bare function object is always truthy,
# which would select 'cuda:0' even on CPU-only machines.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [4]:
# Single source of truth for the checkpoint, so the processor and the model
# are always loaded from the same one.
MODEL_CHECKPOINT = 'microsoft/trocr-small-handwritten'

# Processor: builds `pixel_values` from images and decodes generated token ids.
processor = TrOCRProcessor.from_pretrained(MODEL_CHECKPOINT)
# NOTE(review): despite the name, this is the off-the-shelf pretrained
# checkpoint; nothing in this notebook fine-tunes it.
trained_model = VisionEncoderDecoderModel.from_pretrained(MODEL_CHECKPOINT).to(device)

Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-small-handwritten and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
def read_and_show(image_path):
    """
    Load an image from disk as an RGB PIL Image.

    Despite its name, this function does not display anything; it only
    reads the file and converts it to RGB.

    :param image_path: String, path to the input image.

    Returns:
        image: PIL Image in 'RGB' mode.
    """
    # The context manager closes the file handle deterministically (Pillow only
    # auto-closes single-frame images after loading). `convert` returns a new,
    # fully loaded image, so the result stays valid after the file is closed.
    with Image.open(image_path) as source:
        return source.convert('RGB')

In [6]:
def ocr(image, processor, model):
    """
    Run OCR on a single image and return the decoded text.

    :param image: PIL Image.
    :param processor: Huggingface OCR processor (e.g. TrOCRProcessor); used both
        to build `pixel_values` from the image and to decode generated token ids.
    :param model: Huggingface OCR model exposing `generate` and `device`.

    Returns:
        generated_text: the OCR'd text string for the single input image.
    """
    # Put the inputs wherever the model actually lives, instead of relying on a
    # notebook-global `device` that may not match the model's placement.
    pixel_values = processor(image, return_tensors='pt').pixel_values.to(model.device)
    generated_ids = model.generate(pixel_values)
    # batch_decode returns one string per generated sequence; a single image
    # yields a batch of one, so take the first entry.
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

In [7]:
def _plot_prediction(image, text):
    """Draw `image` on a new figure titled with the OCR'd `text`; return the figure."""
    fig, ax = plt.subplots(figsize=(7, 4))
    ax.set_title(text)
    ax.imshow(image)
    ax.axis('off')
    return fig


def eval_new_data(
    data_path=None,
    num_samples=50,
    show_image=False,
    output_dir='pretrained_model_inference',
):
    """
    Run OCR over the images matching `data_path` and save annotated figures.

    Uses the notebook-level `processor` and `trained_model`.

    :param data_path: String, glob pattern selecting the input images (required).
    :param num_samples: Int, maximum number of images to process; -1 (or any
        negative value, or None) processes every matching image.
    :param show_image: Bool, also display each annotated figure inline.
    :param output_dir: String, directory receiving one figure per image, named
        after the input file. Created if it does not exist.

    Raises:
        TypeError: if `data_path` is not given.
    """
    if data_path is None:
        raise TypeError('eval_new_data() requires data_path (a glob pattern, e.g. "images/*").')

    # Sort so that the subset selected by `num_samples` is reproducible; raw
    # glob order depends on the filesystem.
    image_paths = sorted(glob.glob(data_path))
    if num_samples is not None and num_samples >= 0:
        # Slicing up front also makes the tqdm total reflect the real workload.
        image_paths = image_paths[:num_samples]

    os.makedirs(output_dir, exist_ok=True)

    for image_path in tqdm(image_paths):
        image = read_and_show(image_path)
        text = ocr(image, processor, trained_model)

        fig = _plot_prediction(image, text)
        # basename is portable, unlike splitting on '/' (Windows uses '\\').
        fig.savefig(os.path.join(output_dir, os.path.basename(image_path)))
        if show_image:
            plt.show()
        # Close every figure explicitly to avoid memory growth over thousands of images.
        plt.close(fig)

In [8]:
# Output folder for the annotated prediction figures written by eval_new_data.
os.makedirs(name='pretrained_model_inference', exist_ok=True)

# OCR every test image (num_samples=-1 means "no limit") without inline display.
eval_new_data(
    num_samples=-1,
    show_image=False,
    data_path=os.path.join('input/gnhk_dataset/test_processed/images', '*'),
)

100%|██████████| 10066/10066 [08:18<00:00, 20.21it/s]
