Operator Overloading

Operators such as +, -, *, / and \ can be applied to built-in numerical data types (as well as some other types). Classes specify the representation of user-defined data types, which raises the question of which operators can be applied to members of a user-defined class. First, members have attributes that can be accessed with the dot operator ".". Beyond the dot operator, if addition, subtraction, multiplication, division or other operations make sense for the concept a class implements, they can be made available for direct use on members of that class. This mechanism is called operator overloading.

The operators that can be overloaded are +, -, *, /, \, ^, .*, ./, [] (with one or two indices), and the dot operator . used for attribute selection. To overload an operator, define a member attribute with the corresponding name. Table 7.1 lists each operator with its attribute name.

Table 7.1: Overloaded Operators

Operator    Attribute name
--------    --------------
+           _sum
-           _sub
*           _mul
/           _divide
\           _backdivide
[]          _sqbracket
[,]         _sqbracket2
^           _power
.*          _dotmul
./          _dotdivide
.           _field
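The lookup described by Table 7.1 amounts to rewriting an operator expression into a call on a specially named attribute. The following Python sketch models that rewrite. It is not Shang code. The dictionary `OVERLOAD_NAMES`, the function `desugar`, the toy class `Meters`, and the string spellings of the operators (for example `"[,]"` for two-index subscripting) are hypothetical names chosen for illustration.

```python
# Hypothetical model of Table 7.1: operator symbol -> attribute name.
OVERLOAD_NAMES = {
    "+": "_sum",
    "-": "_sub",
    "*": "_mul",
    "/": "_divide",
    "\\": "_backdivide",
    "[]": "_sqbracket",
    "[,]": "_sqbracket2",
    "^": "_power",
    ".*": "_dotmul",
    "./": "_dotdivide",
    ".": "_field",
}


def desugar(op, operand, *args):
    """Rewrite `operand <op> args` into `operand.<name>(*args)`.

    Raises TypeError when the operand's class does not define the
    attribute, i.e. when the operator is not overloaded for it.
    """
    name = OVERLOAD_NAMES[op]
    method = getattr(operand, name, None)
    if method is None:
        raise TypeError(f"operator {op!r} is not overloaded ({name} missing)")
    return method(*args)


class Meters:
    """Toy class that overloads only '+' by defining _sum."""

    def __init__(self, value):
        self.value = value

    def _sum(self, other):
        return Meters(self.value + other.value)


total = desugar("+", Meters(2), Meters(3))
print(total.value)  # 5
```

In this model, an operator whose attribute is missing is reported as an error instead of being silently ignored. This mirrors the idea that only operators whose attributes the class defines are available on its members.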

Suppose class s is defined as follows:

s = class
    common _sum = function y -> z
        ...
    end

    common _sqbracket = function i -> z
        ...
    end

    common _sqbracket2 = function (i, j) -> z
        ...
    end

    common _dotmul = function y -> z
        ...
    end

    common _field = function str -> z
        ...
    end
end
Then, if x is a member of class s, each expression on the left is evaluated as the call on the right:

    x + y      x._sum(y)
    x[i]       x._sqbracket(i)
    x[i, j]    x._sqbracket2(i, j)
    x .* y     x._dotmul(y)
    x.str      x._field(str)
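Readers who know Python may find its special methods a useful analogy. The sketch below is an analogy only, not Shang. Each Python hook (`__add__`, `__getitem__`, `__getattr__`) forwards to a method named after the corresponding Shang attribute, so `x + y`, `x[i]`, `x[i, j]` and `x.name` end up in `_sum`, `_sqbracket`, `_sqbracket2` and `_field`. Several details are assumptions made for illustration: the class name `S`, the dense row-major storage, the method bodies, and passing the field name to `_field` as a string. Python has no `.*` operator, so `_dotmul` is not modeled.

```python
class S:
    """Python analogue of class s: each operator hook forwards to the
    method named after the Shang attribute in Table 7.1."""

    def __init__(self, rows):
        self.rows = [list(r) for r in rows]  # dense row-major storage

    # x + y  ->  x._sum(y)
    def __add__(self, y):
        return self._sum(y)

    # x[i] -> x._sqbracket(i);  x[i, j] -> x._sqbracket2(i, j)
    def __getitem__(self, key):
        if isinstance(key, tuple):
            return self._sqbracket2(*key)
        return self._sqbracket(key)

    # x.name -> x._field("name"); Python calls this hook only when
    # ordinary attribute lookup fails
    def __getattr__(self, name):
        if name.startswith("_"):
            raise AttributeError(name)
        return self._field(name)

    def _sum(self, y):
        return S([[a + b for a, b in zip(r, q)] for r, q in zip(self.rows, y.rows)])

    def _sqbracket(self, n):
        # 1-based linear index, row-major (same convention as sparsemat below)
        ncols = len(self.rows[0])
        return self.rows[(n - 1) // ncols][(n - 1) % ncols]

    def _sqbracket2(self, i, j):
        return self.rows[i - 1][j - 1]  # 1-based indices

    def _field(self, name):
        if name == "nrows":
            return len(self.rows)
        if name == "ncolumns":
            return len(self.rows[0])
        raise AttributeError(name)


x = S([[1, 2], [3, 4]])
y = S([[10, 20], [30, 40]])
print((x + y).rows)   # [[11, 22], [33, 44]]
print(x[3], x[1, 2])  # 3 2
print(x.ncolumns)     # 2
```

The guard on names starting with "_" keeps `__getattr__` from sending lookups of missing private names into `_field`.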

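The following example uses overloading to define a sparse matrix class. Shang already has built-in support for sparse matrices, which is much more efficient and feature-rich, so this class serves only as an illustration. Besides the operators in Table 7.1, the class defines _subasgn and _subasgn2, which it uses to implement indexed assignment (x[n] = y and x[i, j] = y). A single index n numbers the elements row by row, starting at 1.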

sparsemat = class
     private index = [1, 1];
     private y = 0; 

     readonly nrows = 10;
     readonly ncolumns = 10;
     readonly size = 100;
     readonly nzn = 0;

     new = function (nrows, ncolumns) -> ()
                 parent.nrows = nrows;
                 parent.ncolumns = ncolumns;
                 parent.size = nrows * ncolumns;
           end
     
     common _subasgn = function (n, y) -> ()
               if n >= 1 && n <= parent.size
                    indx = [fix((n - 1) / parent.ncolumns) + 1, (n - 1) % parent.ncolumns + 1];
                    for k = 1 : parent.nzn
                          if parent.index[k,:] == indx
                                if y == 0
                                    parent.index[k,:] = parent.index[parent.nzn, :];
                                    parent.y[k] = parent.y[parent.nzn];
                                    --parent.nzn;
                                else
                                    parent.y[k] = y;
                                end
                                return;
                          end
                    end

                    if y == 0
                         return;
                    end
                    ++parent.nzn;
                    parent.index[parent.nzn, :] = indx;
                    parent.y[parent.nzn] = y;
               else
                    panic("index out of bounds");
               end

            end

     common _subasgn2 = function (idx1, idx2, y) -> ()
               if idx1 >= 1 && idx1 <= parent.nrows && idx2 >= 1 && idx2 <= parent.ncolumns
                    indx = [idx1, idx2];
                    for k = 1 : parent.nzn
                          if parent.index[k,:] == indx
                                if y == 0
                                    parent.index[k,:] = parent.index[parent.nzn, :];
                                    parent.y[k] = parent.y[parent.nzn];
                                    --parent.nzn;
                                else
                                    parent.y[k] = y;
                                end
                                return;
                          end
                    end

                    if y == 0
                         return;
                    end
                    ++parent.nzn;
                    parent.index[parent.nzn, :] = indx;
                    parent.y[parent.nzn] = y;
               else
                    panic("index out of bounds");
               end
            end

     common _sqbracket = function n -> y
               if n >= 1 && n <= parent.size
                    y = 0;
                    indx = [fix((n - 1) / parent.ncolumns) + 1, (n - 1) % parent.ncolumns + 1];
                    for k = 1 : parent.nzn
                          if parent.index[k,:] == indx
                                y = parent.y[k];
                                return;
                          end
                    end
               else
                    panic("index out of bounds");
               end
            end

     common _sqbracket2 = function (i, j) -> y
               if i >= 1 && i <= parent.nrows && j >= 1 && j <= parent.ncolumns
                    y = 0;
                    for k = 1 : parent.nzn
                          if parent.index[k,:] == [i, j]
                                y = parent.y[k];
                                return;
                          end
                    end
               else
                    panic("index out of bounds");
               end
            end

     common _mul = function x -> b
          if parent.ncolumns == x.nrows && x.ncolumns == 1
               b = zeros(parent.nrows, 1);

               for k = 1 : b.nrows
                    v = 0;
                    for j = 1 : parent.ncolumns
                         v += parent[k, j] * x[j];
                    end
                    b[k] = v;
               end
          else
               panic("dimension mismatch");
          end
     end

     common _sum = function x -> b
          if x in type(parent) && parent.nrows == x.nrows && parent.ncolumns == x.ncolumns
               T = type(parent);
               b = T.new(parent.nrows, parent.ncolumns);
               for k = 1 : x.nrows
                   for j = 1 : x.ncolumns
                        b[k, j] = parent[k, j] + x[k, j];
                   end
               end
          else
               panic("dimension mismatch");
          end
     end
end
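To experiment with the algorithm outside Shang, the following Python sketch mirrors the storage scheme of sparsemat. Nonzeros are kept as a list of 1-based (row, column) pairs with a parallel list of values, and deleting an entry moves the last stored entry into the freed slot. A single index is converted to (row, column) in row-major order, as in _sqbracket. This is a model under those assumptions, not a translation of Shang semantics. The class name `SparseMat` and its attribute names are hypothetical, and Python's `__getitem__`, `__setitem__`, `__mul__` and `__add__` stand in for _sqbracket/_sqbracket2, _subasgn/_subasgn2, _mul and _sum.

```python
class SparseMat:
    """Coordinate-list sparse matrix modeled on the sparsemat class.

    Nonzeros are kept as parallel lists of 1-based (row, col) pairs and
    values; explicit zeros are never stored.
    """

    def __init__(self, nrows, ncolumns):
        self.nrows = nrows
        self.ncolumns = ncolumns
        self.size = nrows * ncolumns
        self.index = []  # (row, col) pairs, 1-based
        self.vals = []   # value for each pair in self.index

    @property
    def nzn(self):
        """Number of stored nonzero entries."""
        return len(self.vals)

    def _to_pair(self, key):
        """Map a 1-based row-major linear index or an (i, j) tuple to (i, j)."""
        if isinstance(key, tuple):
            i, j = key
        else:
            if not 1 <= key <= self.size:
                raise IndexError("index out of bounds")
            i = (key - 1) // self.ncolumns + 1
            j = (key - 1) % self.ncolumns + 1
        if not (1 <= i <= self.nrows and 1 <= j <= self.ncolumns):
            raise IndexError("index out of bounds")
        return (i, j)

    def __getitem__(self, key):  # role of _sqbracket / _sqbracket2
        pair = self._to_pair(key)
        for p, v in zip(self.index, self.vals):
            if p == pair:
                return v
        return 0

    def __setitem__(self, key, y):  # role of _subasgn / _subasgn2
        pair = self._to_pair(key)
        for k, p in enumerate(self.index):
            if p == pair:
                if y == 0:
                    # swap-remove: overwrite slot k with the last entry, drop the tail
                    self.index[k] = self.index[-1]
                    self.vals[k] = self.vals[-1]
                    self.index.pop()
                    self.vals.pop()
                else:
                    self.vals[k] = y
                return
        if y != 0:
            self.index.append(pair)
            self.vals.append(y)

    def __mul__(self, x):  # role of _mul: matrix times a vector (list)
        if len(x) != self.ncolumns:
            raise ValueError("dimension mismatch")
        b = [0] * self.nrows
        for (i, j), v in zip(self.index, self.vals):
            b[i - 1] += v * x[j - 1]  # only stored nonzeros are visited
        return b

    def __add__(self, other):  # role of _sum
        if (self.nrows, self.ncolumns) != (other.nrows, other.ncolumns):
            raise ValueError("dimension mismatch")
        b = SparseMat(self.nrows, self.ncolumns)
        for (i, j), v in zip(self.index, self.vals):
            b[i, j] = v
        for (i, j), v in zip(other.index, other.vals):
            b[i, j] = b[i, j] + v  # a sum of zero removes the entry
        return b


a = SparseMat(3, 3)
a[1, 1] = 2
a[5] = 4               # linear index 5 is row 2, column 2
a[3, 3] = 1
print(a * [1, 1, 1])   # [2, 4, 1]
a[5] = 0               # deletes the (2, 2) entry
print(a.nzn, a[2, 2])  # 2 0
```

The Shang _mul above computes each product through parent[k, j], which searches the stored entries once per matrix position. The sketch instead loops over the stored nonzeros once, which is the main payoff of a coordinate-list layout for matrix-vector products.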

oz 2009-12-22