Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
5 changes: 2 additions & 3 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@

name: Continuous Integration


on:
push:
branches:
- '**' # matches every branch
- main
pull_request:
branches:
- '**' # matches every branch


permissions:
Expand Down
3 changes: 1 addition & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -236,5 +236,4 @@ cython_debug/
/docs_state/_build/
/docs_state/_static/logos/
/docs_state/changelog.md
/examples_classic/dynamics_training/data/
/docs/
/examples/dynamics_training/data/
6 changes: 2 additions & 4 deletions brainpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
# limitations under the License.
# ==============================================================================

__version__ = "2.7.3"
__version_info__ = (2, 7, 3)
__version__ = "2.7.4"
__version_info__ = tuple(map(int, __version__.split(".")))


from brainpy import _errors as errors
Expand Down Expand Up @@ -142,7 +142,6 @@
ArrayCollector as ArrayCollector,
Collector as Collector,
)
from brainpy import state

from brainpy.deprecations import deprecation_getattr

Expand All @@ -151,7 +150,6 @@


if __name__ == '__main__':
state
connect
initialize, # weight initialization
optim, # gradient descent optimizers
Expand Down
13 changes: 7 additions & 6 deletions brainpy/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing import Any, Union

import brainstate
from brainpy.math.defaults import env
from brainpy.tools.dicts import DotDict

__all__ = [
Expand All @@ -40,14 +41,14 @@ def __init__(self):

@property
def dt(self):
return brainstate.environ.get_dt()
return brainstate.environ.get_dt(env=env)

@dt.setter
def dt(self, dt):
self.set_dt(dt)

def set_dt(self, dt: Union[int, float]):
brainstate.environ.set(dt=dt)
brainstate.environ.set(dt=dt, env=env)

def load(self, key, value: Any = None, desc: str = None):
"""Load the shared data by the ``key``.
Expand All @@ -57,16 +58,16 @@ def load(self, key, value: Any = None, desc: str = None):
value (Any): the default value when ``key`` is not defined in the shared.
desc: (str): the description of the key.
"""
return brainstate.environ.get(key, value, desc)
return brainstate.environ.get(key, value, desc, env=env)

def save(self, *args, **kwargs) -> None:
"""Save shared arguments in the global context."""
assert len(args) % 2 == 0
for i in range(0, len(args), 2):
identifier = args[i]
data = args[i + 1]
brainstate.environ.set(**{identifier: data})
brainstate.environ.set(**kwargs)
brainstate.environ.set(**{identifier: data}, env=env)
brainstate.environ.set(**kwargs, env=env)

def __setitem__(self, key, value):
"""Enable setting the shared item by ``bp.share[key] = value``."""
Expand All @@ -78,7 +79,7 @@ def __getitem__(self, item):

def get_shargs(self) -> DotDict:
"""Get all shared arguments in the global context."""
return DotDict(brainstate.environ.all())
return DotDict(brainstate.environ.all(env=env))


share = _ShareContext()
2 changes: 1 addition & 1 deletion brainpy/dyn/others/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def __init__(
self.reset_state(self.mode)

def update(self):
spikes = bm.random.rand_like(self.spike) <= (self.freqs * share['dt'] / 1000.)
spikes = bm.random.rand_like(self.spike.value) <= (self.freqs * share['dt'] / 1000.)
spikes = bm.asarray(spikes, dtype=self.spk_type)
# import jax
# jax.debug.print('PoissonGroup: freqs = {f}, spikes = {s}', f=self.freqs, s=spikes)
Expand Down
17 changes: 9 additions & 8 deletions brainpy/inputs/currents.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import braintools

import brainstate
import brainpy.math

__all__ = [
'section_input',
Expand Down Expand Up @@ -59,7 +60,7 @@ def section_input(values, durations, dt=None, return_length=False):

current_and_duration
"""
with brainstate.environ.context(dt=brainstate.environ.get_dt() if dt is None else dt):
with brainstate.environ.context(dt=brainpy.math.get_dt() if dt is None else dt):
return braintools.input.section(values, durations, return_length=return_length)


Expand Down Expand Up @@ -88,7 +89,7 @@ def constant_input(I_and_duration, dt=None):
current_and_duration : tuple
(The formatted current, total duration)
"""
with brainstate.environ.context(dt=brainstate.environ.get_dt() if dt is None else dt):
with brainstate.environ.context(dt=brainpy.math.get_dt() if dt is None else dt):
return braintools.input.constant(I_and_duration)


Expand Down Expand Up @@ -136,7 +137,7 @@ def spike_input(sp_times, sp_lens, sp_sizes, duration, dt=None):
current : bm.ndarray
The formatted input current.
"""
with brainstate.environ.context(dt=brainstate.environ.get_dt() if dt is None else dt):
with brainstate.environ.context(dt=brainpy.math.get_dt() if dt is None else dt):
return braintools.input.spike(sp_times, sp_lens, sp_sizes, duration)


Expand Down Expand Up @@ -175,7 +176,7 @@ def ramp_input(c_start, c_end, duration, t_start=0, t_end=None, dt=None):
current : bm.ndarray
The formatted current
"""
with brainstate.environ.context(dt=brainstate.environ.get_dt() if dt is None else dt):
with brainstate.environ.context(dt=brainpy.math.get_dt() if dt is None else dt):
return braintools.input.ramp(c_start, c_end, duration, t_start, t_end)


Expand Down Expand Up @@ -210,7 +211,7 @@ def wiener_process(duration, dt=None, n=1, t_start=0., t_end=None, seed=None):
seed: int
The noise seed.
"""
with brainstate.environ.context(dt=brainstate.environ.get_dt() if dt is None else dt):
with brainstate.environ.context(dt=brainpy.math.get_dt() if dt is None else dt):
return braintools.input.wiener_process(duration, sigma=1.0, n=n, t_start=t_start, t_end=t_end, seed=seed)


Expand Down Expand Up @@ -242,7 +243,7 @@ def ou_process(mean, sigma, tau, duration, dt=None, n=1, t_start=0., t_end=None,
seed: optional, int
The random seed.
"""
with brainstate.environ.context(dt=brainstate.environ.get_dt() if dt is None else dt):
with brainstate.environ.context(dt=brainpy.math.get_dt() if dt is None else dt):
return braintools.input.ou_process(mean, sigma, tau, duration, n=n, t_start=t_start, t_end=t_end, seed=seed)


Expand All @@ -267,7 +268,7 @@ def sinusoidal_input(amplitude, frequency, duration, dt=None, t_start=0., t_end=
Whether the sinusoid oscillates around 0 (False), or
has a positive DC bias, thus non-negative (True).
"""
with brainstate.environ.context(dt=brainstate.environ.get_dt() if dt is None else dt):
with brainstate.environ.context(dt=brainpy.math.get_dt() if dt is None else dt):
return braintools.input.sinusoidal(amplitude, frequency, duration, t_start=t_start, t_end=t_end, bias=bias)


Expand All @@ -292,5 +293,5 @@ def square_input(amplitude, frequency, duration, dt=None, bias=False, t_start=0.
Whether the sinusoid oscillates around 0 (False), or
has a positive DC bias, thus non-negative (True).
"""
with brainstate.environ.context(dt=brainstate.environ.get_dt() if dt is None else dt):
with brainstate.environ.context(dt=brainpy.math.get_dt() if dt is None else dt):
return braintools.input.square(amplitude, frequency, duration, t_start=t_start, t_end=t_end, duty_cycle=0.5, bias=bias)
39 changes: 21 additions & 18 deletions brainpy/math/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from .modes import NonBatchingMode
from .scales import IdScaling

env = brainstate.environ.EnvironmentState()


class setting:
def __init__(self):
Expand All @@ -42,79 +44,80 @@ def __init__(self):
# default return array type
# numpy_func_return='jax_array', # 'bp_array','jax_array'
numpy_func_return='bp_array', # 'bp_array','jax_array'
env=env,
)

@property
def mode(self):
return brainstate.environ.get('mode')
return brainstate.environ.get('mode', env=env)

@property
def membrane_scaling(self):
return brainstate.environ.get('membrane_scaling')
return brainstate.environ.get('membrane_scaling', env=env)

@property
def dt(self):
return brainstate.environ.get('dt')
return brainstate.environ.get('dt', env=env)

@property
def bool_(self):
return brainstate.environ.get('bool_')
return brainstate.environ.get('bool_', env=env)

@property
def int_(self):
return brainstate.environ.get('int_')
return brainstate.environ.get('int_', env=env)

@property
def float_(self):
return brainstate.environ.get('float_')
return brainstate.environ.get('float_', env=env)

@property
def complex_(self):
return brainstate.environ.get('complex_')
return brainstate.environ.get('complex_', env=env)

@property
def bp_object_as_pytree(self):
return brainstate.environ.get('bp_object_as_pytree')
return brainstate.environ.get('bp_object_as_pytree', env=env)

@property
def numpy_func_return(self):
return brainstate.environ.get('numpy_func_return')
return brainstate.environ.get('numpy_func_return', env=env)

@mode.setter
def mode(self, value):
brainstate.environ.set(mode=value)
brainstate.environ.set(mode=value, env=env)

@membrane_scaling.setter
def membrane_scaling(self, value):
brainstate.environ.set(membrane_scaling=value)
brainstate.environ.set(membrane_scaling=value, env=env)

@dt.setter
def dt(self, value):
brainstate.environ.set(dt=value)
brainstate.environ.set(dt=value, env=env)

@bool_.setter
def bool_(self, value):
brainstate.environ.set(bool_=value)
brainstate.environ.set(bool_=value, env=env)

@int_.setter
def int_(self, value):
brainstate.environ.set(int_=value)
brainstate.environ.set(int_=value, env=env)

@float_.setter
def float_(self, value):
brainstate.environ.set(float_=value)
brainstate.environ.set(float_=value, env=env)

@complex_.setter
def complex_(self, value):
brainstate.environ.set(complex_=value)
brainstate.environ.set(complex_=value, env=env)

@bp_object_as_pytree.setter
def bp_object_as_pytree(self, value):
brainstate.environ.set(bp_object_as_pytree=value)
brainstate.environ.set(bp_object_as_pytree=value, env=env)

@numpy_func_return.setter
def numpy_func_return(self, value):
brainstate.environ.set(numpy_func_return=value)
brainstate.environ.set(numpy_func_return=value, env=env)


defaults = setting()
45 changes: 0 additions & 45 deletions brainpy/state/__init__.py

This file was deleted.

Loading
Loading