Fading Coder

One Final Commit for the Last Sprint

Home > Tech > Content

Training a Custom CNN Image Classifier with TensorFlow.NET in C#

Tech 1

TensorFlow.NET provides a TensorFlow-compatible API for .NET Standard that closely mirrors the Python experience while integrating naturally with the SciSharp stack (NumSharp, SharpCV, Keras.NET, etc.). The example below builds and trains a compact CNN for grayscale image classification entirely in C#, targeting local datasets and running on CPU or GPU.

Project scope

  • Task: single-character OCR for industrial printed characters
  • Classes: 3 (X/Y/Z)
  • Image shape: 64×64, 1 channel
  • Dataset split:
    • train: X/Y/Z = 384/384/384
    • validation: X/Y/Z = 96/96/96
    • test: X/Y/Z = 96/96/96
  • Augmentations applied offline: random flip, pan, scale, mirroring

The pipeline crops single characters with OpenCV upstream and feeds normalized 64×64 patches into the CNN. In application logic, per-character predictions can be concatenated into strings.

Environment

  • .NET: .NET Framework 4.7.2+ or .NET Core 2.2+
  • CPU: Any CPU / x64
  • GPU: CUDA and cuDNN properly installed and on PATH (e.g., CUDA 10.1, cuDNN 7.5)

Dependencies

Install from NuGet (or equivalent):

<ItemGroup>
  <PackageReference Include="TensorFlow.NET" Version="0.14.0" />
  <PackageReference Include="SciSharp.TensorFlow.Redist" Version="1.15.0" />
  <PackageReference Include="NumSharp" Version="0.30.0" />
  <PackageReference Include="SharpCV" Version="0.2.0" />
  <PackageReference Include="System.Drawing.Common" Version="4.7.0" />
  <PackageReference Include="Newtonsoft.Json" Version="12.0.3" />
  <PackageReference Include="SharpZipLib" Version="1.2.0" />
  <PackageReference Include="Colorful.Console" Version="1.2.9" />
</ItemGroup>

Namespaces used:

using System;
using System.IO;
using System.Linq;
using System.Collections;
using System.Collections.Generic;
using System.Collections.Concurrent;
using System.Diagnostics;
using System.Threading.Tasks;
using NumSharp;
using SharpCV;
using Tensorflow;
using static Tensorflow.Binding;
using static SharpCV.Binding;

Program flow

/// <summary>
/// End-to-end entry point: loads the dataset, builds the graph, trains and
/// evaluates within one session, then prints the per-sample test report.
/// </summary>
/// <returns>true when test accuracy reaches the 98% acceptance bar.</returns>
public bool Run()
{
    LoadData();
    CreateGraph();

    using (var session = tf.Session())
    {
        TrainModel(session);
        Evaluate(session);
    }

    EmitPerSampleReport();

    // Acceptance threshold for the trained model.
    return testAccuracy >= 0.98f;
}

Data access

Download and extract (optional)

// Example: download and unzip to a local folder
// Fetch the sample dataset archive into the dataset root, then unpack it in place.
var zipUrl = "https://github.com/SciSharp/SciSharp-Stack-Examples/raw/master/data/data_CnnInYourOwnData.zip";
var root = Name; // dataset root folder
Directory.CreateDirectory(root);
var zipPath = Path.Combine(root, "data_CnnInYourOwnData.zip");

using (var client = new System.Net.WebClient())
{
    client.DownloadFile(zipUrl, zipPath);
}

var fastZip = new ICSharpCode.SharpZipLib.Zip.FastZip();
fastZip.ExtractZip(zipPath, root, null); // null = no file filter, extract everything

Label dictionary from folders

/// <summary>
/// Builds the class-id → label-name dictionary from the immediate
/// sub-directories of <paramref name="baseDir"/> (one folder per class)
/// and records the class count in n_classes.
/// </summary>
private void BuildLabelMap(string baseDir)
{
    Dict_Label = new Dictionary<long, string>();

    var classDirs = Directory.GetDirectories(baseDir, "*", SearchOption.TopDirectoryOnly);
    long index = 0;
    foreach (var dir in classDirs)
    {
        var label = Path.GetFileName(dir);
        Dict_Label[index] = label;
        print($"{index} : {label}");
        index++;
    }

    n_classes = Dict_Label.Count;
}

Split lists and labels

// Enumerate every image file under each split directory (recursively, one
// sub-folder per class) and derive the aligned numeric label for each path.
ArrayFileName_Train = Directory.GetFiles(Path.Combine(Name, "train"), "*.*", SearchOption.AllDirectories);
ArrayLabel_Train     = PathsToLabels(ArrayFileName_Train);

ArrayFileName_Validation = Directory.GetFiles(Path.Combine(Name, "validation"), "*.*", SearchOption.AllDirectories);
ArrayLabel_Validation    = PathsToLabels(ArrayFileName_Validation);

ArrayFileName_Test = Directory.GetFiles(Path.Combine(Name, "test"), "*.*", SearchOption.AllDirectories);
ArrayLabel_Test    = PathsToLabels(ArrayFileName_Test);

/// <summary>
/// Maps each image path to its numeric class id, using the parent directory
/// name (the class folder) as the label key.
/// </summary>
/// <param name="paths">Image file paths; each must live in a folder named after its class.</param>
/// <returns>Label ids aligned index-for-index with <paramref name="paths"/>.</returns>
/// <exception cref="InvalidOperationException">A parent folder name is not present in Dict_Label.</exception>
private long[] PathsToLabels(string[] paths)
{
    // Invert the id->name map once instead of scanning it per file
    // (the previous Dict_Label.Single(...) made this O(paths × classes)).
    var idByName = Dict_Label.ToDictionary(p => p.Value, p => p.Key);

    var labels = new long[paths.Length];
    for (int i = 0; i < paths.Length; i++)
    {
        var folder = Directory.GetParent(paths[i]).Name;
        if (!idByName.TryGetValue(folder, out labels[i]))
            throw new InvalidOperationException($"Unknown class folder '{folder}' for file '{paths[i]}'.");
    }
    return labels;
}

Shuffle aligned arrays (Fisher–Yates)

/// <summary>
/// Shuffles the path and label arrays in place with a single Fisher–Yates
/// pass, keeping the two arrays aligned by index, and returns them.
/// </summary>
public (string[] paths, long[] labels) ShuffleAligned(string[] paths, long[] labels)
{
    var random = new Random();

    for (int last = paths.Length - 1; last > 0; last--)
    {
        int pick = random.Next(last + 1);

        var tmpPath = paths[last];
        paths[last] = paths[pick];
        paths[pick] = tmpPath;

        var tmpLabel = labels[last];
        labels[last] = labels[pick];
        labels[pick] = tmpLabel;
    }

    print($"shuffled {paths.Length} samples");
    return (paths, labels);
}

Preload validation/test tensors

// Loads all validation and test data into memory up front so each
// evaluation can be done in a single sess.run over the whole split.
private void PreloadEvalTensors()
{
    // One-hot encode labels via NumSharp fancy indexing: row i of the
    // identity matrix is the one-hot vector for class i.
    y_valid = np.eye(Dict_Label.Count)[new NDArray(ArrayLabel_Validation)];
    y_test  = np.eye(Dict_Label.Count)[new NDArray(ArrayLabel_Test)];
    print("Loaded evaluation labels");

    // Pre-allocate NHWC image tensors, then fill them from disk.
    x_valid = np.zeros((ArrayFileName_Validation.Length, img_h, img_w, n_channels));
    x_test  = np.zeros((ArrayFileName_Test.Length,       img_h, img_w, n_channels));

    LoadImagesInto(x_valid, ArrayFileName_Validation, "validation");
    LoadImagesInto(x_test,  ArrayFileName_Test,       "test");
}

/// <summary>
/// Reads, resizes and normalizes every image in <paramref name="srcPaths"/>
/// into the pre-allocated destination tensor, printing a progress dot every
/// 32 images and a summary line when done.
/// </summary>
private void LoadImagesInto(NDArray dst, string[] srcPaths, string tag)
{
    int done = 0;
    foreach (var path in srcPaths)
    {
        dst[done] = ReadAndNormalizeImage(path);
        done++;
        if (done % 32 == 0)
            Console.Write(".");
    }

    Console.WriteLine();
    Console.WriteLine($"Loaded {tag} images");
}

// Decodes one image file into a normalized float32 tensor of shape
// (1, img_h, img_w, n_channels): decode -> cast -> resize (bicubic)
// -> (x - img_mean) / img_std.
// NOTE(review): a fresh Graph and Session are built per image, which is
// very slow for large datasets — consider building this pipeline once and
// reusing it (the ImagePipeline scope in CreateGraph does exactly that).
// NOTE(review): decode_jpeg is used, so non-JPEG inputs presumably fail —
// confirm the dataset's file format.
private NDArray ReadAndNormalizeImage(string path)
{
    using (var graph = tf.Graph().as_default())
    {
        var file = tf.read_file(path);
        var img  = tf.image.decode_jpeg(file, channels: n_channels);
        var f32  = tf.cast(img, tf.float32);
        var e    = tf.expand_dims(f32, 0); // add batch dimension
        var sz   = tf.constant(new int[] { img_h, img_w });
        var res  = tf.image.resize_bicubic(e, sz);
        var norm = tf.divide(tf.subtract(res, new float[] { img_mean }), new float[] { img_std });

        using (var sess = tf.Session(graph))
            return sess.run(norm);
    }
}

Graph construction

Two conv+pool blocks, followed by a dense hidden layer and a logits layer. Training uses cross-entropy with softmax, with a step-wise learning-rate schedule.

/// <summary>
/// Builds the classification graph: two conv+pool blocks, a ReLU dense
/// layer and a logits layer, plus loss/optimizer/metric ops and a small
/// normalization sub-graph ("ImagePipeline") used by the async batch loader.
/// </summary>
/// <returns>The constructed (and now default) graph.</returns>
public Graph CreateGraph()
{
    var graph = new Graph().as_default();

    tf_with(tf.name_scope("Inputs"), () =>
    {
        // -1 leaves the batch dimension dynamic.
        features = tf.placeholder(tf.float32, shape: (-1, img_h, img_w, n_channels), name: "features");
        targets  = tf.placeholder(tf.float32, shape: (-1, n_classes),           name: "targets");
    });

    var c1   = Conv2DLayer(features, filter_size1, num_filters1, stride1, name: "Conv1");
    var p1   = MaxPool(c1, 2, 2, name: "Pool1");
    var c2   = Conv2DLayer(p1, filter_size2, num_filters2, stride2, name: "Conv2");
    var p2   = MaxPool(c2, 2, 2, name: "Pool2");
    var flat = Flatten(p2);
    // BUG FIX: the local was previously named `h1`, shadowing the
    // hidden-unit-count field of the same name; `hidden_units: h1` then
    // referenced the local inside its own initializer, which does not
    // compile. Renamed the local to `fc1`.
    var fc1    = Dense(flat, hidden_units: h1, name: "Dense1", relu: true);
    var logits = Dense(fc1, hidden_units: n_classes, name: "Logits", relu: false);

    // Bake preprocessing constants into the graph so inference-side
    // consumers can recover them from the saved model.
    tf.constant(img_h, name: "img_h");
    tf.constant(img_w, name: "img_w");
    tf.constant(img_mean, name: "img_mean");
    tf.constant(img_std,  name: "img_std");
    tf.constant(n_channels, name: "img_channels");

    globalStep   = tf.Variable(0, trainable: false, name: "global_step");
    lr           = tf.Variable(learning_rate_base, name: "learning_rate");

    tf_with(tf.variable_scope("ImagePipeline"), () =>
    {
        // Accepts raw bytes or uint8 NDArray; used by the async batch loader
        imageInput = tf.placeholder(tf.@byte, name: "image_bytes_or_u8");
        var cast   = tf.cast(imageInput, tf.float32);
        var expand = tf.expand_dims(cast, 0);
        var rs     = tf.image.resize_bicubic(expand, tf.constant(new int[] { img_h, img_w }));
        normalized = tf.identity(tf.divide(tf.subtract(rs, new float[] { img_mean }), new float[] { img_std }), name: "normalized");
    });

    tf_with(tf.variable_scope("Train"), () =>
    {
        tf_with(tf.variable_scope("Loss"), () =>
        {
            // Cross-entropy applied to raw logits (softmax fused inside the op).
            loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels: targets, logits: logits), name: "loss");
        });

        tf_with(tf.variable_scope("Optimizer"), () =>
        {
            optimizer = tf.train.AdamOptimizer(learning_rate: lr, name: "adam").minimize(loss, global_step: globalStep);
        });

        tf_with(tf.variable_scope("Metrics"), () =>
        {
            // Accuracy = mean of (argmax(logits) == argmax(targets)).
            var pred = tf.argmax(logits, 1);
            var tru  = tf.argmax(targets, 1);
            var ok   = tf.equal(pred, tru);
            accuracy = tf.reduce_mean(tf.cast(ok, tf.float32), name: "accuracy");
        });

        tf_with(tf.variable_scope("Outputs"), () =>
        {
            // Named ops for inference: winning class id and full probability vector.
            predictedClass = tf.argmax(logits, axis: 1, name: "class_id");
            probabilities  = tf.nn.softmax(logits, axis: 1, name: "probs");
        });
    });

    return graph;
}

/// <summary>
/// Convolution block under its own variable scope:
/// conv2d (SAME padding) + bias + ReLU.
/// </summary>
private Tensor Conv2DLayer(Tensor x, int k, int filters, int stride, string name)
{
    return tf_with(tf.variable_scope(name), () =>
    {
        // Kernel is k×k over however many channels the input carries.
        var channelsIn = x.shape[x.NDims - 1];
        var kernel = WeightVar("W", new[] { k, k, channelsIn, filters });
        var bias   = BiasVar("b", new[] { filters });

        var conv     = tf.nn.conv2d(x, kernel, strides: new[] { 1, stride, stride, 1 }, padding: "SAME");
        var withBias = tf.nn.bias_add(conv, bias);
        return tf.nn.relu(withBias);
    });
}

/// <summary>SAME-padded max pooling with a k×k window and the given stride.</summary>
private Tensor MaxPool(Tensor x, int k, int stride, string name)
    => tf.nn.max_pool(
        x,
        ksize: new[] { 1, k, k, 1 },
        strides: new[] { 1, stride, stride, 1 },
        padding: "SAME",
        name: name);

// Collapses the spatial/channel dimensions into a single feature dimension;
// -1 keeps the batch dimension dynamic.
// NOTE(review): Slice(1, 4) assumes a rank-4 NHWC input — confirm if this
// helper is ever reused on tensors of other ranks.
private Tensor Flatten(Tensor x)
{
    return tf_with(tf.variable_scope("Flatten"), () =>
    {
        var dims = x.TensorShape;
        var n = dims[new Slice(1, 4)].size;
        return tf.reshape(x, new[] { -1, n });
    });
}

/// <summary>Fully connected layer: x·W + b, optionally followed by ReLU.</summary>
private Tensor Dense(Tensor x, int hidden_units, string name, bool relu)
{
    return tf_with(tf.variable_scope(name), () =>
    {
        var featuresIn = x.shape[1];
        var weights = WeightVar("W", new[] { featuresIn, hidden_units });
        var bias    = BiasVar("b", new[] { hidden_units });

        var linear = tf.matmul(x, weights) + bias;
        if (relu)
            return tf.nn.relu(linear);
        return linear;
    });
}

/// <summary>Trainable weight tensor initialized from a truncated normal (σ = 0.01).</summary>
private RefVariable WeightVar(string name, int[] shape)
    => tf.get_variable(
        name,
        dtype: tf.float32,
        shape: shape,
        initializer: tf.truncated_normal_initializer(stddev: 0.01f));

/// <summary>Trainable bias tensor initialized to zero.</summary>
private RefVariable BiasVar(string name, int[] shape)
    => tf.get_variable(name, dtype: tf.float32, initializer: tf.constant(0f, shape: shape));

Training and checkpointing

  • Batches are formed asynchronously with BlockingCollection to overlap disk IO and GPU/CPU compute
  • Learning rate decays by a factor at fixed epoch steps (clamped by a minimum)
  • Checkpoints can save the best model or every epoch
/// <summary>
/// Trains the network for <c>epochs</c> epochs. Batches are produced on a
/// background task feeding a bounded BlockingCollection so disk IO overlaps
/// compute; the learning rate decays step-wise; checkpoints are saved either
/// best-only or per epoch, and the label map is written at the end.
/// </summary>
/// <param name="sess">Session holding the graph built by CreateGraph.</param>
public void TrainModel(Session sess)
{
    // Integer division: any remainder samples are dropped each epoch.
    int stepsPerEpoch = ArrayLabel_Train.Length / batch_size;

    sess.run(tf.global_variables_initializer());
    var saver = tf.train.Saver(tf.global_variables(), max_to_keep: 10);

    path_model = Path.Combine(Name, "MODEL");
    Directory.CreateDirectory(path_model);

    float valLoss = float.MaxValue;
    float valAcc  = 0f;
    var timer = new Stopwatch();

    for (int epoch = 0; epoch < epochs; epoch++)
    {
        print($"Epoch {epoch + 1}");

        // Re-shuffle every epoch and rebuild the aligned one-hot labels.
        (ArrayFileName_Train, ArrayLabel_Train) = ShuffleAligned(ArrayFileName_Train, ArrayLabel_Train);
        y_train = np.eye(Dict_Label.Count)[new NDArray(ArrayLabel_Train)];

        // Step-wise learning-rate decay, clamped at learning_rate_min.
        if (learning_rate_step > 0 && epoch > 0 && epoch % learning_rate_step == 0)
        {
            learning_rate_base = Math.Max(learning_rate_min, learning_rate_base * learning_rate_decay);
            sess.run(tf.assign(lr, learning_rate_base));
        }

        var queue = new BlockingCollection<(NDArray X, NDArray Y, int iter)>(TrainQueueCapa);

        // BUG FIX: CompleteAdding previously ran outside any try/finally and
        // the task was never awaited — an exception while loading a batch
        // left the consumer blocked forever and the fault unobserved.
        // CompleteAdding now always runs, and the task is awaited below so
        // any producer exception surfaces.
        var producer = Task.Run(() =>
        {
            try
            {
                for (int iter = 0; iter < stepsPerEpoch; iter++)
                {
                    int start = iter * batch_size;
                    int end   = (iter + 1) * batch_size;
                    var (bx, by) = NextBatch(sess, ArrayFileName_Train, y_train, start, end);
                    queue.Add((bx, by, iter));
                }
            }
            finally
            {
                queue.CompleteAdding();
            }
        });

        timer.Restart();
        foreach (var it in queue.GetConsumingEnumerable())
        {
            sess.run(optimizer, (features, it.X), (targets, it.Y));

            // Periodic progress readout (extra forward pass on the same batch).
            if (it.iter % display_freq == 0)
            {
                var res = sess.run(new[] { loss, accuracy }, (features, it.X), (targets, it.Y));
                print($"iter {it.iter:000}: loss={res[0]:0.0000}, acc={res[1]:P} {timer.ElapsedMilliseconds}ms");
                timer.Restart();
            }
        }
        producer.Wait(); // rethrows any exception captured by the producer task

        // Full validation pass at the end of the epoch.
        var eval = sess.run((loss, accuracy), (features, x_valid), (targets, y_valid));
        valLoss = eval.Item1;
        valAcc  = eval.Item2;
        print("---------------------------------------------------------");
        print($"global_step: {sess.run(globalStep)}, lr: {sess.run(lr)}, val_loss: {valLoss:0.0000}, val_acc: {valAcc:P}");
        print("---------------------------------------------------------");

        if (saveBestOnly)
        {
            // Keep only the checkpoint with the best validation accuracy so far.
            if (valAcc > bestValAcc)
            {
                bestValAcc = valAcc;
                saver.save(sess, Path.Combine(path_model, "CNN_Best"));
                print("Saved best checkpoint.");
            }
        }
        else
        {
            var ckpt = Path.Combine(path_model, $"CNN_Epoch_{epoch}_Loss_{valLoss}_Acc_{valAcc}");
            saver.save(sess, ckpt);
            print("Saved epoch checkpoint.");
        }
    }

    // Persist the id->label map alongside the checkpoints for inference.
    SaveLabelMap(Path.Combine(path_model, "labels.txt"), Dict_Label);
}

/// <summary>Writes the label map as "id,name" lines, overwriting any existing file.</summary>
private void SaveLabelMap(string path, Dictionary<long, string> map)
{
    var lines = map.Select(kv => $"{kv.Key},{kv.Value}");
    File.WriteAllLines(path, lines);
    print("Wrote label map");
}

/// <summary>Returns the [start, end) rows of both tensors as an aligned pair.</summary>
private (NDArray, NDArray) SliceBatch(NDArray X, NDArray Y, int start, int end)
{
    var window = new Slice(start, end);
    return (X[window], Y[window]);
}

/// <summary>
/// Loads one training batch: reads each image as grayscale with OpenCV,
/// pushes it through the graph's ImagePipeline normalization ops, and slices
/// the matching one-hot label rows.
/// </summary>
/// <param name="sess">Session owning the ImagePipeline ops.</param>
/// <param name="paths">Shuffled training file paths.</param>
/// <param name="oneHotY">One-hot labels aligned with <paramref name="paths"/>.</param>
/// <param name="start">Inclusive start index of the batch.</param>
/// <param name="end">Exclusive end index of the batch.</param>
/// <returns>(images, labels) for the batch.</returns>
private (NDArray, NDArray) NextBatch(Session sess, string[] paths, NDArray oneHotY, int start, int end)
{
    // FIX: dropped the `unsafe` modifier — the body contains no pointer
    // code, so the unsafe context was unnecessary.
    var bx = np.zeros((end - start, img_h, img_w, n_channels));
    int n = 0;
    for (int i = start; i < end; i++)
    {
        // cv2.imread returns a uint8 NDArray; normalization happens in-graph.
        NDArray img = cv2.imread(paths[i], IMREAD_COLOR.IMREAD_GRAYSCALE);
        bx[n++] = sess.run(normalized, (imageInput, img));
    }

    var by = oneHotY[new Slice(start, end)];
    return (bx, by);
}

Evaluation and per-sample output

/// <summary>
/// Runs the full test set through the graph, recording loss, accuracy,
/// predicted class ids and per-class probabilities in fields for the
/// per-sample report.
/// </summary>
public void Evaluate(Session sess)
{
    var (lossValue, accValue) = sess.run((loss, accuracy), (features, x_test), (targets, y_test));
    testLoss = lossValue;
    testAccuracy = accValue;

    print("---------------------------------------------------------");
    print($"test_loss: {testLoss:0.0000}, test_acc: {testAccuracy:P}");
    print("---------------------------------------------------------");

    // Second pass collects the per-sample predictions used by EmitPerSampleReport.
    (testPredClass, testProb) = sess.run((predictedClass, probabilities), (features, x_test));
}

/// <summary>
/// Prints one pipe-delimited line per test sample: OK/NG status, ground
/// truth and predicted labels, the winning-class probability, and the
/// source file name.
/// </summary>
private void EmitPerSampleReport()
{
    int sample = 0;
    foreach (long truthIdx in ArrayLabel_Test)
    {
        int predIdx = (int)testPredClass[sample];
        var prob    = testProb[sample, predIdx].GetSingle();

        var truth   = Dict_Label[truthIdx];
        var predict = Dict_Label[predIdx];
        var status  = truthIdx == predIdx ? "OK" : "NG";

        print($"{sample + 1}|result:{status}|real_str:{truth}|predict_str:{predict}|probability:{prob}|fileName:{ArrayFileName_Test[sample]}");
        sample++;
    }
}
Tags: csharp.NET

Related Articles

Understanding Strong and Weak References in Java

Strong References Strong references are the most prevalent type of object referencing in Java. When an object has a strong reference pointing to it, the garbage collector will not reclaim its memory. F...

Comprehensive Guide to SSTI Explained with Payload Bypass Techniques

Introduction Server-Side Template Injection (SSTI) is a vulnerability in web applications where user input is improperly handled within the template engine and executed on the server. This exploit can r...

Implement Image Upload Functionality for Django Integrated TinyMCE Editor

Django’s Admin panel is highly user-friendly, and pairing it with TinyMCE, an effective rich text editor, simplifies content management significantly. Combining the two is particularly useful for bloggi...

Leave a Comment

Anonymous

◎Feel free to join the discussion and share your thoughts.