サブロウ丸

主にプログラミングと数学

[ python ] 自作クラスをnumpy ndarrayに使用する

Subclassing ndarray — NumPy v1.18 Manual (公式)を参考にすれば良いでしょう.

以下, 簡単な実装例.

class Element:
    def __init__(self, r, c):
        self.r = r
        self.c = c
        
    def __repr__(self):
        return f'Element({self.r},{self.c})'

上の自作クラスを, 2次元ndarrayの要素として, r行c列目にElement(r, c)を使いたいとします.

MyNdarrayクラスをnp.ndarrayを継承して以下のように作ります.

import numpy as np

class MyNdarray(np.ndarray):
    def __new__(cls, dimension):
        shape = (dimension, dimension)
        return super().__new__(cls, shape, dtype=Element)

    def __init__(self, dimension):
        for r in range(dimension):
            for c in range(dimension):
                self[r, c] = Element(r, c)


__new__ でshapeが(dimension * dimension) の numpy ndarray instanceを作成して, __init__ で初期操作(コンストラクタ)が行われます.

a = MyNdarray(3)
>> MyNdarray([[Element(0,0), Element(0,1), Element(0,2)],
           [Element(1,0), Element(1,1), Element(1,2)],
           [Element(2,0), Element(2,1), Element(2,2)]], dtype=object)
a[0, 0]
>> Element(0,0)


ndarrayのスライシング操作や転置機能をそのまま使うことができるのがいいですね.

a[:, 0:2]
>> MyNdarray([[Element(0,0), Element(0,1)],
           [Element(1,0), Element(1,1)],
           [Element(2,0), Element(2,1)]], dtype=object)
a.T
>> MyNdarray([[Element(0,0), Element(1,0), Element(2,0)],
           [Element(0,1), Element(1,1), Element(2,1)],
           [Element(0,2), Element(1,2), Element(2,2)]], dtype=object)



また, MyNdarray クラスに, 属性を追加したいときは__array_finalize__ 関数を使います.

import numpy as np

class MyNdarray(np.ndarray):
    def __new__(cls, dimension):
        shape = (dimension, dimension)
        obj = super().__new__(cls, shape, dtype=Element)
        obj.dimension = dimension
        return obj
    
    def __init__(self, dimension):
        for r in range(dimension):
            for c in range(dimension):
                self[r, c] = Element(r, c)
                
    def __array_finalize__(self, obj):
        if obj is None:
            return
        self.dimension = getattr(obj, 'dimension', None)
a = MyNdarray(3)
a.dimension
>>> 3