# Source code for upy.operators

# Copyright (c) 2010 Friedrich Romstedt <friedrichromstedt@gmail.com>
# See also <www.friedrichromstedt.org> (if e-mail has changed)
# 
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# 
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

# Developed since: Feb 2010

import numpy
import upy.core

# Public names of this module: one wrapper instance per overloaded numpy
# operator.
__all__ = [
    'add',
    'subtract',
    'multiply',
    'divide',
    'power',
    'less',
    'less_equal',
    'greater',
    'greater_equal',
    'equal',
    'not_equal',
]

"""Overloads the numpy operators in such a way, that in expressions
undarrays have the highest precedence."""

#
# Explicit operators ...
#
# Consider the expression:
#
#   numpyarray * upyarray
#
# When executing this, NUMPYARRAY.__mul__() is called, or, equivalently,
# numpy.multiply().  This function checks whether the other operand is a
# numpy.ndarray, and if not, it treats it as a scalar and applies the
# operation to all elements of the numpy.ndarray NUMPYARRAY.  This is not
# what was expected.  The call executes properly if upyarray.__rmul__() is
# being called, which is done by the wrapper functions below.  The wrapper
# functions only handle this special case; all other cases are handed over to
# the numpy functions.  The wrapper functions are registered in numpy via
# numpy.set_numeric_ops().

# Arithmetic operators ...

# Calling numpy.set_numeric_ops() without arguments changes nothing; its
# return value is kept here so the wrappers can reach the original numpy
# operators.  ufuncWrap.__init__() looks up its .ufunc attribute in this
# mapping by ufunc name (e.g. original_numpy_ops['add']).
# NOTE(review): numpy.set_numeric_ops() is deprecated in newer NumPy
# releases -- confirm it is available in the targeted NumPy version.
original_numpy_ops = numpy.set_numeric_ops()

class ufuncWrap:
    """Wraps a numpy ufunc so that undarray operands take precedence.

    An instance behaves like the wrapped ufunc (unknown attributes are
    forwarded to it), with the exception of __call__(): when the second
    operand is an undarray, the call is handed over to a method of that
    undarray instead of to the ufunc.
    """

    def __init__(self, ufunc_name, overload):
        """UFUNC_NAME is the name (string) of the ufunc to be wrapped; the
        ufunc itself is taken from ``original_numpy_ops``.  OVERLOAD is the
        name (string) of the undarray method to be used in overloading
        __call__()."""

        self.ufunc_name = ufunc_name
        self.ufunc = original_numpy_ops[ufunc_name]
        self.overload = overload

    def __call__(self, a, b, *args, **kwargs):
        """When B is an undarray, return B.<overload>(A), else return
        .ufunc(A, B, *ARGS, **KWARGS).

        Note that in the undarray branch any extra ARGS and KWARGS are not
        passed on."""

        if isinstance(b, upy.core.undarray):
            return getattr(b, self.overload)(a)
        else:
            return self.ufunc(a, b, *args, **kwargs)

    def __getattr__(self, attr):
        """Return getattr(.ufunc, ATTR).

        Raises AttributeError when the wrapper has no .ufunc yet."""

        # __getattr__() is only consulted after the normal lookup failed.
        # If 'ufunc' itself is missing (instance created without running
        # __init__(), e.g. while being copied or unpickled), reading
        # self.ufunc below would re-enter this method without end.
        if attr == 'ufunc':
            raise AttributeError(attr)
        return getattr(self.ufunc, attr)

    def __str__(self):
        """Return a short human-readable description of the wrapper."""

        return "(ufunc wrapper for %s)" % self.ufunc

    def __repr__(self):
        """Return a constructor-like representation of the wrapper."""

        return "ufuncWrap(ufunc_name = %r, overload = %r)" % \
                (self.ufunc_name, self.overload)
class Add(ufuncWrap):
    """Wrapper for numpy's ``add`` which hands over to the ``__radd__``
    method of an undarray right-hand operand."""

    def __init__(self):
        ufuncWrap.__init__(self, ufunc_name='add', overload='__radd__')


add = Add()
class Subtract(ufuncWrap):
    """Wrapper for numpy's ``subtract`` which hands over to the ``__rsub__``
    method of an undarray right-hand operand."""

    def __init__(self):
        ufuncWrap.__init__(self, ufunc_name='subtract', overload='__rsub__')


subtract = Subtract()
class Multiply(ufuncWrap):
    """Wrapper for numpy's ``multiply`` which hands over to the ``__rmul__``
    method of an undarray right-hand operand."""

    def __init__(self):
        ufuncWrap.__init__(self, ufunc_name='multiply', overload='__rmul__')


multiply = Multiply()
class Divide(ufuncWrap):
    """Wrapper for numpy's ``divide`` which hands over to the ``__rdiv__``
    method of an undarray right-hand operand."""

    # NOTE(review): '__rdiv__' is the Python 2 spelling of reflected
    # division -- confirm undarray provides it on the targeted interpreter.
    def __init__(self):
        ufuncWrap.__init__(self, ufunc_name='divide', overload='__rdiv__')


divide = Divide()
class Power(ufuncWrap):
    """Wrapper for numpy's ``power`` which hands over to the ``__rpow__``
    method of an undarray right-hand operand."""

    def __init__(self):
        ufuncWrap.__init__(self, ufunc_name='power', overload='__rpow__')


power = Power()

# Comparison operators ...
#
# Note that for the antisymmetric operators the undarray method called is
# the mirrored counterpart of the original operator, because the operands
# swap positions.
class Less(ufuncWrap):
    """Wrapper for numpy's ``less``.  As the operands swap positions,
    ``a < b`` is evaluated as ``b.__gt__(a)`` for an undarray ``b``."""

    def __init__(self):
        ufuncWrap.__init__(self, ufunc_name='less', overload='__gt__')


less = Less()
class LessEqual(ufuncWrap):
    """Wrapper for numpy's ``less_equal``.  As the operands swap positions,
    ``a <= b`` is evaluated as ``b.__ge__(a)`` for an undarray ``b``."""

    def __init__(self):
        ufuncWrap.__init__(self, ufunc_name='less_equal', overload='__ge__')


less_equal = LessEqual()
class Greater(ufuncWrap):
    """Wrapper for numpy's ``greater``.  As the operands swap positions,
    ``a > b`` is evaluated as ``b.__lt__(a)`` for an undarray ``b``."""

    def __init__(self):
        ufuncWrap.__init__(self, ufunc_name='greater', overload='__lt__')


greater = Greater()
class GreaterEqual(ufuncWrap):
    """Wrapper for numpy's ``greater_equal``.  As the operands swap
    positions, ``a >= b`` is evaluated as ``b.__le__(a)`` for an undarray
    ``b``."""

    def __init__(self):
        ufuncWrap.__init__(self, ufunc_name='greater_equal',
                overload='__le__')


greater_equal = GreaterEqual()
class Equal(ufuncWrap):
    """Wrapper for numpy's ``equal`` which hands over to the ``__eq__``
    method of an undarray right-hand operand."""

    def __init__(self):
        ufuncWrap.__init__(self, 'equal', '__eq__')

    def __call__(self, a, b, *args, **kwargs):
        """Unwrap a zero-dimensional ndarray B to the element it stores,
        then proceed as ufuncWrap.__call__()."""

        # numpy's calling mechanism of equal() seems to have a bug, such
        # that b is always a numpy.ndarray.  When b should be an undarray,
        # it is a numpy.ndarray(dtype = numpy.object, shape = ()) ...
        # Make the call also compatible with future, bug-fixed versions.
        if isinstance(b, numpy.ndarray) and b.ndim == 0:
            # Indexing with the empty tuple returns the stored element
            # unchanged.  A reduction like b.sum() would promote bool and
            # small integer dtypes and fails for dtypes which cannot be
            # added (e.g. strings).
            b = b[()]

        return ufuncWrap.__call__(self, a, b, *args, **kwargs)


equal = Equal()
class NotEqual(ufuncWrap):
    """Wrapper for numpy's ``not_equal`` which hands over to the ``__ne__``
    method of an undarray right-hand operand."""

    def __init__(self):
        ufuncWrap.__init__(self, 'not_equal', '__ne__')

    def __call__(self, a, b, *args, **kwargs):
        """Unwrap a zero-dimensional ndarray B to the element it stores,
        then proceed as ufuncWrap.__call__()."""

        # numpy's calling mechanism of not_equal() seems to have a bug, such
        # that b is always a numpy.ndarray.  When b should be an undarray,
        # it is a numpy.ndarray(dtype = numpy.object, shape = ()) ...
        # Make the call also compatible with future, bug-fixed versions.
        if isinstance(b, numpy.ndarray) and b.ndim == 0:
            # Indexing with the empty tuple returns the stored element
            # unchanged.  A reduction like b.sum() would promote bool and
            # small integer dtypes and fails for dtypes which cannot be
            # added (e.g. strings).
            b = b[()]

        return ufuncWrap.__call__(self, a, b, *args, **kwargs)
not_equal = NotEqual()

# Register the operators in numpy ...
#
# From here on numpy uses the wrapper instances in place of its own
# operators whenever it evaluates one of the operations listed.
numpy.set_numeric_ops(
    add=add,
    subtract=subtract,
    multiply=multiply,
    divide=divide,
    power=power,
    less=less,
    less_equal=less_equal,
    greater=greater,
    greater_equal=greater_equal,
    equal=equal,
    not_equal=not_equal,
)