# DEMO 4-1: Handwritten Digit Recognition App

## Train a random forest model
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
import joblib
import numpy as np

# Load the MNIST train/test split from a local .npz archive.
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute code embedded in the file. Plain numeric image/label arrays do not
# need it, so only keep it for a file you trust.
path = './data/mnist.npz'
with np.load(path, allow_pickle=True) as f:
    x_train, y_train = f["x_train"], f["y_train"]
    x_test, y_test = f["x_test"], f["y_test"]

print(x_train.shape, y_train.shape)

# scikit-learn estimators require 2-D input of shape (n_samples, n_features).
# The standard mnist.npz stores each image as a 2-D array (presumably 28x28;
# confirm with the shape printed above), so flatten every image into one row.
# This matches the app below, which feeds the model a single 784-value row.
x_train = x_train.reshape(len(x_train), -1)
x_test = x_test.reshape(len(x_test), -1)

clf = RandomForestClassifier(n_estimators=200, random_state=42)
clf.fit(x_train, y_train)
y_pred = clf.predict(x_test)
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)

# Persist the fitted model; the Gradio app loads it from this same path.
joblib.dump(clf, './data/random_forest_model.pkl')
## build the app
import gradio as gr
import joblib
import numpy as np

# Load the pre-trained Random Forest model saved by the training step above.
## If you skipped training, you can download the model file from Hugging Face
## instead. Run this before the load below; it saves to the same path that is
## loaded:
# import os
# import requests
# url = "https://huggingface.co/spaces/JunchuanYu/handwritten_recognition/resolve/main/data/random_forest_model.pkl"
# response = requests.get(url)
# response.raise_for_status()
# os.makedirs("./data", exist_ok=True)
# with open("./data/random_forest_model.pkl", "wb") as file:
#     file.write(response.content)
# print("Download completed")
#
# NOTE: joblib.load unpickles the file, which can execute arbitrary code.
# Only load model files from a source you trust.
model = joblib.load('./data/random_forest_model.pkl')

def _resize_to(pixels, height, width):
    """Resample a 2-D array to ``height`` x ``width``.

    Shrinking averages each source cell's pixels (area average). Enlarging
    repeats the nearest source pixel. Returns a float array.
    """
    out = np.asarray(pixels, dtype=float)
    for axis, size in ((0, height), (1, width)):
        # Split the axis into `size` contiguous bins of (nearly) equal length.
        edges = np.linspace(0, out.shape[axis], size + 1).astype(int)
        starts = edges[:-1]
        # Empty bins (only when enlarging) make reduceat return the single
        # element at that start index, so divide those by 1.
        counts = np.maximum(np.diff(edges), 1)
        out = np.add.reduceat(out, starts, axis=axis)
        shape = [1, 1]
        shape[axis] = size
        out = out / counts.reshape(shape)
    return out


# Function to predict the digit
def predict_minist(image):
    """Predict the digit drawn on the Gradio ImageMask canvas.

    Args:
        image: Value passed by ``gr.ImageMask``; its ``"composite"`` entry is
            read as an H x W x 4 (RGBA) array. It may be ``None``, or have a
            ``None`` composite, e.g. when the canvas is cleared (TODO confirm
            the exact Gradio behaviour).

    Returns:
        The label predicted by ``model`` for the drawing, or ``""`` when there
        is nothing to classify.
    """
    if image is None or image.get("composite") is None:
        return ""
    # Use the alpha channel as the grayscale digit. Presumably drawn strokes
    # are opaque and the untouched canvas is transparent.
    alpha = np.asarray(image["composite"])[:, :, -1]
    if alpha.size == 0:
        return ""
    if alpha.shape != (28, 28):
        # The model expects 784 features (28x28). Resample other canvas sizes
        # instead of failing inside reshape.
        alpha = _resize_to(alpha, 28, 28)
    features = alpha.reshape(1, 784)
    prediction = model.predict(features)
    print(alpha.shape, np.max(alpha), prediction[0])
    return prediction[0]

# HTML banner rendered at the top of the page.
_HEADER_HTML = """
        <center> 
        <h1>Handwritten Digit Recognition</h1>
        <b>jason.yu.mail@qq.com 📧</b>
        </center>
        """
# One-line usage hint shown under the banner.
_USAGE_HINT = "Draw a digit and the model will predict the digit. Please draw the digit in the center of the canvas"

with gr.Blocks(theme="soft") as demo:
    gr.Markdown(_HEADER_HTML)
    gr.Markdown(_USAGE_HINT)

    # The prediction box sits on its own row above the drawing canvas.
    with gr.Row():
        outtext = gr.Textbox(label="Prediction")
    with gr.Row():
        inputimg = gr.ImageMask(image_mode="RGBA", crop_size=(28, 28))

    # Re-classify whenever the drawing changes.
    inputimg.change(fn=predict_minist, inputs=inputimg, outputs=outtext)

demo.launch(height=550, width="100%", show_api=False)