mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-07 20:41:20 +08:00
68 lines
1.8 KiB
Python
68 lines
1.8 KiB
Python
"""Module for common feature initializers."""
|
|
from __future__ import absolute_import
|
|
|
|
from . import backend as F
|
|
|
|
# Public names of this module: the names exported by a star-import.
__all__ = ["base_initializer", "zero_initializer"]
|
|
|
|
|
|
def base_initializer(
    shape, dtype, ctx, id_range
):  # pylint: disable=unused-argument
    """The function signature for feature initializer.

    Any customized feature initializer should follow this signature (see
    example below). This function itself is only a template: calling it
    always raises ``NotImplementedError``.

    Parameters
    ----------
    shape : tuple of int
        The shape of the result features. The first dimension
        is the batch dimension.
    dtype : data type object
        The data type of the returned features.
    ctx : context object
        The device context of the returned features.
    id_range : slice
        The start id and the end id of the features to be initialized.
        The id could be node or edge id depending on the scenario.
        Note that the step is always None.

    Returns
    -------
    Nothing from this template. A customized initializer is expected to
    return the initialized features with the requested ``shape``,
    ``dtype`` and ``ctx`` (see the example below).

    Raises
    ------
    NotImplementedError
        Always; this function only documents the expected signature.

    Examples
    --------
    If PyTorch is used as backend, the following code defines a feature
    initializer that initializes tensor value to 1

    >>> import torch
    >>> import dgl
    >>> def initializer(shape, dtype, ctx, id_range):
    ...     return torch.ones(shape, dtype=dtype, device=ctx)
    >>> g = dgl.DGLGraph()
    >>> g.set_n_initializer(initializer)

    See Also
    --------
    dgl.DGLGraph.set_n_initializer
    dgl.DGLGraph.set_e_initializer
    """
    raise NotImplementedError
|
|
|
|
|
|
def zero_initializer(
    shape, dtype, ctx, id_range
):  # pylint: disable=unused-argument
    """Zero feature initializer.

    Delegates to the backend's ``zeros`` with the requested shape, data
    type and device context.

    Parameters
    ----------
    shape : tuple of int
        The shape of the features to create.
    dtype : data type object
        The data type passed to the backend ``zeros`` call.
    ctx : context object
        The device context passed to the backend ``zeros`` call.
    id_range : slice
        Accepted so the function matches the feature-initializer
        signature; it is not used here.

    Returns
    -------
    The value returned by ``F.zeros(shape, dtype, ctx)`` (presumably a
    backend tensor filled with zeros -- confirm against the backend).

    Examples
    --------
    >>> import dgl
    >>> g = dgl.DGLGraph()
    >>> g.set_n_initializer(dgl.init.zero_initializer)

    See Also
    --------
    dgl.DGLGraph.set_n_initializer
    dgl.DGLGraph.set_e_initializer
    """
    # id_range is intentionally ignored: every id gets the same zero value.
    return F.zeros(shape, dtype, ctx)
|