Radial Basis Function Networks in Java

What is a Radial Basis Function Network?

This is a type of neural network that performs a similar job to a multilayer neural network. That is, it can analyze higher-dimensional data, including data that a single straight line cannot separate.


How it actually works

What this actually does is place n points (the centers) across your data set. Then, starting from an assigned base equation (here a Gaussian), it changes certain terms in that equation, such as the centers and the weights, so that the curve of the changed equation represents the data set.

After training, it decides how best to match the data based on the functions at those n points.
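
Written out, this idea is commonly expressed as a weighted sum of Gaussian curves. This is only a sketch of the general form. The symbols $w_j$, $c_j$, $\sigma_j$ and $\varphi_j$ are notation chosen here, matched to the fields w, c and sigma in the Java listing further down.

```latex
y(x) = \sum_{j=1}^{n} w_j \,\varphi_j(x), \qquad
\varphi_j(x) = \exp\!\left(-\frac{\lVert x - c_j \rVert^2}{2\sigma_j^2}\right)
```

Each $\varphi_j$ is largest when the input $x$ sits on the center $c_j$ and falls off with distance, and $\sigma_j$ controls how quickly it falls off.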

So let's say you had to solve the EX-OR problem and make a neural net learn its truth table, which is as follows:

    A   B   A EX-OR B
    0   0   0
    0   1   1
    1   0   1
    1   1   0

Obviously there is no straight line that separates the inputs that give 1 from the inputs that give 0, so you have to draw a curved line that fits the data. This is where RBF can help. Let's say we put two points on the input space and draw two Gaussian curves at those points.

Well, now our data is split up quite nicely.

That's basically how RBF works.
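
To make the split concrete, here is a minimal sketch that maps the four EX-OR inputs through two Gaussian curves. It assumes the curves sit at (0,0) and (1,1) with a width of 0.5, the same values the listing further down uses. The class name GaussianSplitSketch and the 0.5 threshold on the sum are assumptions made for this illustration only.

```java
// Hypothetical helper class for illustration; not part of the original listing.
class GaussianSplitSketch {

    // Gaussian activation centered at c with width sigma.
    static double phi(double[] x, double[] c, double sigma) {
        double dist2 = 0;
        for (int i = 0; i < c.length; i++) {
            dist2 += (x[i] - c[i]) * (x[i] - c[i]);
        }
        return Math.exp(-dist2 / (2 * sigma * sigma));
    }

    public static void main(String[] args) {
        double[][] inputs = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
        double[] c1 = {0, 0};   // first Gaussian center
        double[] c2 = {1, 1};   // second Gaussian center
        double sigma = 0.5;     // assumed width, as in the listing

        for (double[] x : inputs) {
            double p1 = phi(x, c1, sigma);
            double p2 = phi(x, c2, sigma);
            // Assumed decision rule for this sketch: threshold the sum at 0.5.
            String side = (p1 + p2 > 0.5) ? "inputs equal" : "inputs differ";
            System.out.printf("(%.0f,%.0f) -> phi1=%.4f phi2=%.4f  %s%n",
                    x[0], x[1], p1, p2, side);
        }
    }
}
```

In the (phi1, phi2) plane the two inputs that sit on a center end up far from the two that do not, so a single straight line in that plane is enough to separate them.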


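Training is going to be done by one of the simplest techniques, known as the Pseudo-Inverse Technique. The listing below, however, updates the centers and weights step by step using the learning rates n1 and n2.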
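
For reference, here is a minimal sketch of what the pseudo-inverse idea looks like for this small problem. It keeps the two Gaussian centers fixed and solves for the two output weights in one step as w = (Phi^T Phi)^-1 Phi^T d, which is the pseudo-inverse of Phi applied to d when Phi has full column rank. The class name PseudoInverseSketch, the method names, and the restriction to exactly two units without a bias term are assumptions made for this illustration. This is not the author's code.

```java
// Hypothetical sketch class; names are not from the original listing.
class PseudoInverseSketch {

    // Gaussian activation centered at c with width sigma.
    static double phi(double[] x, double[] c, double sigma) {
        double dist2 = 0;
        for (int i = 0; i < c.length; i++) {
            dist2 += (x[i] - c[i]) * (x[i] - c[i]);
        }
        return Math.exp(-dist2 / (2 * sigma * sigma));
    }

    // Least-squares weights for exactly two units: w = (Phi^T Phi)^-1 Phi^T d.
    static double[] solveWeights(double[][] inputs, double[] targets,
                                 double[][] centers, double sigma) {
        double a = 0, b = 0, d = 0; // entries of the symmetric 2x2 matrix Phi^T Phi
        double r0 = 0, r1 = 0;      // entries of the vector Phi^T d
        for (int k = 0; k < inputs.length; k++) {
            double p0 = phi(inputs[k], centers[0], sigma);
            double p1 = phi(inputs[k], centers[1], sigma);
            a += p0 * p0;
            b += p0 * p1;
            d += p1 * p1;
            r0 += p0 * targets[k];
            r1 += p1 * targets[k];
        }
        double det = a * d - b * b; // assumed non-zero (the two units differ)
        // Closed-form inverse of [[a, b], [b, d]] applied to (r0, r1).
        return new double[]{(d * r0 - b * r1) / det, (a * r1 - b * r0) / det};
    }

    // Network output: weighted sum of the two activations.
    static double predict(double[] x, double[] w, double[][] centers, double sigma) {
        return w[0] * phi(x, centers[0], sigma) + w[1] * phi(x, centers[1], sigma);
    }

    public static void main(String[] args) {
        double[][] inputs = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
        double[] targets = {1, 0, 0, 1}; // same targets as the listing
        double[][] centers = {{0, 0}, {1, 1}};
        double sigma = 0.5;

        double[] w = solveWeights(inputs, targets, centers, sigma);
        for (double[] x : inputs) {
            System.out.println(x[0] + "\t" + x[1] + "\t" + predict(x, w, centers, sigma));
        }
    }
}
```

No learning rate or repeated passes are needed here, because the weights come straight out of one linear solve. The trade-off is that the centers and widths have to be chosen beforehand.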


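public class rbfn {

    public static void main(String[] a) {
        Net n = new rbfn().new Net();
        // Run 100 training passes over the four patterns
        // (assumed: the loop body is missing from the listing).
        for (int i = 0; i < 100; i++) {
            n.train();
        }
        // Print the learned parameters and the prediction for each pattern.
        n.test(new double[]{0, 0});
        n.test(new double[]{0, 1});
        n.test(new double[]{1, 0});
        n.test(new double[]{1, 1});
    }

    class Net {
        double[][] inputs = new double[][]{{0, 0}, {0, 1}, {1, 0}, {1, 1}};
        // Targets as given in the listing: 1 when both inputs are equal,
        // i.e. the complement of the EX-OR output.
        double[] outputs = {1, 0, 0, 1};
        Unit[] net = new Unit[2];

        Net() {
            // Two Gaussian units centered on (0,0) and (1,1),
            // each with sigma 0.5 and initial weight 0.5.
            net[0] = new Unit(0.5, new double[]{0, 0}, 0.5);
            net[1] = new Unit(0.5, new double[]{1, 1}, 0.5);
        }

        void train() {
            for (int i = 0; i < inputs.length; i++) {
                double output = outputs[i];
                double predictedoutput = 0;
                // Loop over the units (net.length), not the input dimensions.
                for (int j = 0; j < net.length; j++) {
                    // Assumed forward pass: weighted sum of unit activations.
                    predictedoutput += net[j].w * net[j].phi(inputs[i]);
                }
                //predictedoutput= Math.round(predictedoutput);

                for (int j = 0; j < net.length; j++) {
                    net[j].update(inputs[i], output, predictedoutput);
                }
            }
        }

        void test(double[] inputs) {
            double predictedOutput = 0;
            // Weight and center of every unit, tab separated.
            for (int i = 0; i < net.length; i++) {
                System.out.print(net[i].w + "\t" + net[i].c[0] + "\t" + net[i].c[1] + "\t");
            }
            for (int i = 0; i < net.length; i++) {
                // Assumed: same weighted sum as in train().
                predictedOutput += net[i].w * net[i].phi(inputs);
            }
            // Assumed: the prediction is printed at the end of the row.
            System.out.println(predictedOutput);
        }
    }

    class Unit {
        double w;        // output weight
        double c[];      // center of the Gaussian
        double sigma;    // width of the Gaussian
        double n1 = 0.1; // learning rate for the center
        double n2 = 0.1; // learning rate for the weight (assumed use)

        Unit(double sigma, double[] center, double weight) {
            // Assumed: the constructor simply stores its arguments.
            this.sigma = sigma;
            this.c = center;
            this.w = weight;
        }

        // Gaussian activation of this unit for the given input.
        double phi(double[] input) {
            double distance = 0;
            for (int i = 0; i < c.length; i++) {
                // Assumed: squared Euclidean distance to the center.
                distance += Math.pow(input[i] - c[i], 2);
            }
            return Math.pow(Math.E, -distance / (2 * Math.pow(sigma, 2)));
        }

        // One gradient-style step toward the desired output.
        void update(double[] input, double desired, double output) {
            double phi = phi(input);
            double diffOutput = desired - output;

            for (int i = 0; i < c.length; i++) {
                c[i] = c[i] + (n1 * diffOutput * w * phi * (input[i] - c[i]) / (sigma * sigma));
            }
            // Assumed weight step using n2; this line is missing from the listing.
            w = w + (n2 * diffOutput * phi);
        }
    }
}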
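
The center update in update() can be read as one gradient-descent step on the squared error of a single sample. The derivation below is a sketch under that assumption. $E$, $d$, $y$, $\eta_1$ and $\eta_2$ are notation introduced here ($d$ for desired, $y$ for output, $\eta_1$ and $\eta_2$ for n1 and n2) and do not appear in the listing.

```latex
E = \tfrac{1}{2}(d - y)^2, \qquad
y = \sum_j w_j \varphi_j(x), \qquad
\varphi_j(x) = \exp\!\left(-\frac{\lVert x - c_j \rVert^2}{2\sigma^2}\right)

\frac{\partial \varphi_j}{\partial c_{j,i}} = \varphi_j(x)\,\frac{x_i - c_{j,i}}{\sigma^2}
\quad\Longrightarrow\quad
\frac{\partial E}{\partial c_{j,i}} = -(d - y)\, w_j\, \varphi_j(x)\,\frac{x_i - c_{j,i}}{\sigma^2}

c_{j,i} \leftarrow c_{j,i} + \eta_1\,(d - y)\, w_j\, \varphi_j(x)\,\frac{x_i - c_{j,i}}{\sigma^2}

\frac{\partial E}{\partial w_j} = -(d - y)\,\varphi_j(x)
\quad\Longrightarrow\quad
w_j \leftarrow w_j + \eta_2\,(d - y)\,\varphi_j(x)
```

The third line matches the center update in the code term by term. The last line is the standard weight step, and it is assumed here to be what the learning rate n2 is for.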


07 May 2014