21 commits
- `5eebcf7` feat(aggregation): add GradVac aggregator (rkhosrowshahi, Apr 9, 2026)
- `a588c93` chore: Remove outdated doctesting stuff (#639) (ValerianRey, Apr 11, 2026)
- `9d65f63` chore: Add governance documentation (#637) (PierreQuinton, Apr 11, 2026)
- `3ab336c` refactor(gradvac): literal group types, eps/beta rules, and plotter UX (rkhosrowshahi, Apr 12, 2026)
- `e53849e` refactor(gradvac): base on GramianWeightedAggregator with GradVacWeig… (rkhosrowshahi, Apr 12, 2026)
- `4909964` fix: update type hint for update_gradient_coordinate function (rkhosrowshahi, Apr 12, 2026)
- `a39f343` test(gradvac): cover beta setter success path for codecov (rkhosrowshahi, Apr 12, 2026)
- `0359e60` Rename some variables in test_gradvac.py (ValerianRey, Apr 12, 2026)
- `1da5f6e` Add comment about why we move to cpu (ValerianRey, Apr 12, 2026)
- `21d55f9` Add GradVac to the aggregator table in README (ValerianRey, Apr 12, 2026)
- `17b1dd5` Add changelog entry (ValerianRey, Apr 12, 2026)
- `02a826b` Merge branch 'main' into feature/gradvac (ValerianRey, Apr 12, 2026)
- `f4e8e60` Remove seed setting in test_aggregator_output (ValerianRey, Apr 12, 2026)
- `75c89c1` fix(aggregation): Add fallback in NashMTL (#640) (ValerianRey, Apr 13, 2026)
- `b100c8b` Merge branch 'main' into feature/gradvac (ValerianRey, Apr 13, 2026)
- `193ffa6` Merge branch 'main' of https://github.com/TorchJD/torchjd into featur… (rkhosrowshahi, Apr 13, 2026)
- `9ffdd13` Revert plot test refactors; keep GradVac in interactive plotter (rkhosrowshahi, Apr 13, 2026)
- `50525a1` Merge branch 'main' into feature/gradvac (21f6b74) (rkhosrowshahi, Apr 13, 2026)
- `e626475` docs(aggregation): add grouping usage example and fix GradVac note (rkhosrowshahi, Apr 13, 2026)
- `a244d2b` docs(changelog): split Unreleased into Added and Fixed for GradVac an… (rkhosrowshahi, Apr 13, 2026)
- `1933dea` Merge branch 'main' into feature/gradvac (rkhosrowshahi, Apr 13, 2026)
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,11 @@ changelog does not include internal changes that do not affect the user.

## [Unreleased]

### Added

- Added `GradVac` and `GradVacWeighting` from
[Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models](https://arxiv.org/pdf/2010.05874).

### Fixed

- Added a fallback for when the inner optimization of `NashMTL` fails (which can happen for example
1 change: 1 addition & 0 deletions README.md
@@ -281,6 +281,7 @@ TorchJD provides many existing aggregators from the literature, listed in the fo
| [Constant](https://torchjd.org/stable/docs/aggregation/constant#torchjd.aggregation.Constant) | [ConstantWeighting](https://torchjd.org/stable/docs/aggregation/constant#torchjd.aggregation.ConstantWeighting) | - |
| [DualProj](https://torchjd.org/stable/docs/aggregation/dualproj#torchjd.aggregation.DualProj) | [DualProjWeighting](https://torchjd.org/stable/docs/aggregation/dualproj#torchjd.aggregation.DualProjWeighting) | [Gradient Episodic Memory for Continual Learning](https://arxiv.org/pdf/1706.08840) |
| [GradDrop](https://torchjd.org/stable/docs/aggregation/graddrop#torchjd.aggregation.GradDrop) | - | [Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout](https://arxiv.org/pdf/2010.06808) |
| [GradVac](https://torchjd.org/stable/docs/aggregation/gradvac#torchjd.aggregation.GradVac) | [GradVacWeighting](https://torchjd.org/stable/docs/aggregation/gradvac#torchjd.aggregation.GradVacWeighting) | [Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models](https://arxiv.org/pdf/2010.05874) |
| [IMTLG](https://torchjd.org/stable/docs/aggregation/imtl_g#torchjd.aggregation.IMTLG) | [IMTLGWeighting](https://torchjd.org/stable/docs/aggregation/imtl_g#torchjd.aggregation.IMTLGWeighting) | [Towards Impartial Multi-task Learning](https://discovery.ucl.ac.uk/id/eprint/10120667/) |
| [Krum](https://torchjd.org/stable/docs/aggregation/krum#torchjd.aggregation.Krum) | [KrumWeighting](https://torchjd.org/stable/docs/aggregation/krum#torchjd.aggregation.KrumWeighting) | [Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent](https://proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-Paper.pdf) |
| [Mean](https://torchjd.org/stable/docs/aggregation/mean#torchjd.aggregation.Mean) | [MeanWeighting](https://torchjd.org/stable/docs/aggregation/mean#torchjd.aggregation.MeanWeighting) | - |
14 changes: 14 additions & 0 deletions docs/source/docs/aggregation/gradvac.rst
@@ -0,0 +1,14 @@
:hide-toc:

GradVac
=======

.. autoclass:: torchjd.aggregation.GradVac
   :members:
   :undoc-members:
   :exclude-members: forward

.. autoclass:: torchjd.aggregation.GradVacWeighting
   :members:
   :undoc-members:
   :exclude-members: forward
1 change: 1 addition & 0 deletions docs/source/docs/aggregation/index.rst
@@ -35,6 +35,7 @@ Abstract base classes
   dualproj.rst
   flattening.rst
   graddrop.rst
   gradvac.rst
   imtl_g.rst
   krum.rst
   mean.rst
167 changes: 167 additions & 0 deletions docs/source/examples/grouping.rst
@@ -0,0 +1,167 @@
Grouping
========

When applying a conflict-resolving aggregator such as :class:`~torchjd.aggregation.GradVac` in
multi-task learning, the cosine similarities between task gradients can be computed at different
granularities. The GradVac paper introduces four strategies, each partitioning the shared
parameter vector differently:

1. **Whole Model** (default) — one group covering all shared parameters.
2. **Encoder-Decoder** — one group per top-level sub-network (e.g. encoder and decoder separately).
3. **All Layers** — one group per leaf module of the encoder.
4. **All Matrices** — one group per individual parameter tensor.

In TorchJD, grouping is achieved by calling :func:`~torchjd.autojac.jac_to_grad` once per group
after :func:`~torchjd.autojac.mtl_backward`, with a dedicated aggregator instance per group.
For stateful aggregators such as :class:`~torchjd.aggregation.GradVac`, each instance
independently maintains its own EMA state :math:`\hat{\phi}`, matching the per-block targets from
the original paper.
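To make the role of the EMA state concrete, here is a minimal pure-Python sketch of GradVac's per-pair update rule, as described in the Gradient Vaccine paper: when the cosine similarity :math:`\phi` between two task gradients falls below the EMA target :math:`\hat{\phi}`, the first gradient is nudged toward the second so that their similarity reaches the target, and the target is then updated by an exponential moving average. This is an illustration of the formula only, not torchjd's implementation; the helper names ``cosine`` and ``gradvac_pair`` are hypothetical.

```python
import math

def cosine(u, v):
    # Cosine similarity between two gradient vectors (as plain lists).
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.hypot(*u) * math.hypot(*v))

def gradvac_pair(g_i, g_j, phi_hat, beta=0.01):
    """One GradVac step for a single pair of task gradients (illustrative).

    If cos(g_i, g_j) is below the EMA target phi_hat, g_i is replaced by
    g_i + a * g_j with a chosen so that cos(g_i', g_j) == phi_hat.
    """
    phi = cosine(g_i, g_j)
    if phi < phi_hat:  # similarity below target: vaccinate g_i toward g_j
        norm_i, norm_j = math.hypot(*g_i), math.hypot(*g_j)
        a = (norm_i * (phi_hat * math.sqrt(1 - phi**2)
                       - phi * math.sqrt(1 - phi_hat**2))
             / (norm_j * math.sqrt(1 - phi_hat**2)))
        g_i = [gi + a * gj for gi, gj in zip(g_i, g_j)]
    # EMA update of the target, using the observed (pre-adjustment) similarity.
    phi_hat = (1 - beta) * phi_hat + beta * phi
    return g_i, phi_hat
```

With a per-group aggregator instance, each group runs this update on its own block of the gradients with its own ``phi_hat`` state, which is exactly why each :class:`~torchjd.aggregation.GradVac` instance above must be kept separate.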

.. note::
   The grouping is orthogonal to the choice of
   :func:`~torchjd.autojac.backward` vs :func:`~torchjd.autojac.mtl_backward`. Those functions
   determine *which* parameters receive Jacobians; grouping then determines *how* those Jacobians
   are partitioned for aggregation. Calling :func:`~torchjd.autojac.jac_to_grad` once on all shared
   parameters corresponds to the Whole Model strategy. Splitting those parameters into
   sub-networks and calling :func:`~torchjd.autojac.jac_to_grad` separately on each — with a
   dedicated aggregator per sub-network — gives an arbitrary custom grouping, such as the
   Encoder-Decoder strategy described in the GradVac paper for encoder-decoder architectures.

.. note::
   The examples below use :class:`~torchjd.aggregation.GradVac`, but the same pattern applies to
   any aggregator.

1. Whole Model
--------------

A single :class:`~torchjd.aggregation.GradVac` instance aggregates all shared parameters
together. Cosine similarities are computed between the full task gradient vectors.

.. testcode::
   :emphasize-lines: 14, 19

   import torch
   from torch.nn import Linear, MSELoss, ReLU, Sequential
   from torch.optim import SGD

   from torchjd.aggregation import GradVac
   from torchjd.autojac import jac_to_grad, mtl_backward

   encoder = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
   task1_head, task2_head = Linear(3, 1), Linear(3, 1)
   optimizer = SGD([*encoder.parameters(), *task1_head.parameters(), *task2_head.parameters()], lr=0.1)
   loss_fn = MSELoss()
   inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1)

   gradvac = GradVac()

   for x, y1, y2 in zip(inputs, t1, t2):
       features = encoder(x)
       mtl_backward([loss_fn(task1_head(features), y1), loss_fn(task2_head(features), y2)], features=features)
       jac_to_grad(encoder.parameters(), gradvac)
       optimizer.step()
       optimizer.zero_grad()

2. Encoder-Decoder
------------------

One :class:`~torchjd.aggregation.GradVac` instance per top-level sub-network. Here the model
is split into an encoder and a decoder; cosine similarities are computed separately within each.
Passing ``features=dec_out`` to :func:`~torchjd.autojac.mtl_backward` causes both sub-networks
to receive Jacobians, which are then aggregated independently.

.. testcode::
   :emphasize-lines: 8-9, 15-16, 22-23

   import torch
   from torch.nn import Linear, MSELoss, ReLU, Sequential
   from torch.optim import SGD

   from torchjd.aggregation import GradVac
   from torchjd.autojac import jac_to_grad, mtl_backward

   encoder = Sequential(Linear(10, 5), ReLU())
   decoder = Sequential(Linear(5, 3), ReLU())
   task1_head, task2_head = Linear(3, 1), Linear(3, 1)
   optimizer = SGD([*encoder.parameters(), *decoder.parameters(), *task1_head.parameters(), *task2_head.parameters()], lr=0.1)
   loss_fn = MSELoss()
   inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1)

   encoder_gradvac = GradVac()
   decoder_gradvac = GradVac()

   for x, y1, y2 in zip(inputs, t1, t2):
       enc_out = encoder(x)
       dec_out = decoder(enc_out)
       mtl_backward([loss_fn(task1_head(dec_out), y1), loss_fn(task2_head(dec_out), y2)], features=dec_out)
       jac_to_grad(encoder.parameters(), encoder_gradvac)
       jac_to_grad(decoder.parameters(), decoder_gradvac)
       optimizer.step()
       optimizer.zero_grad()

3. All Layers
-------------

One :class:`~torchjd.aggregation.GradVac` instance per leaf module. Cosine similarities are
computed between the per-layer blocks of the task gradients.

.. testcode::
   :emphasize-lines: 14-15, 20-21

   import torch
   from torch.nn import Linear, MSELoss, ReLU, Sequential
   from torch.optim import SGD

   from torchjd.aggregation import GradVac
   from torchjd.autojac import jac_to_grad, mtl_backward

   encoder = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
   task1_head, task2_head = Linear(3, 1), Linear(3, 1)
   optimizer = SGD([*encoder.parameters(), *task1_head.parameters(), *task2_head.parameters()], lr=0.1)
   loss_fn = MSELoss()
   inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1)

   leaf_layers = [m for m in encoder.modules() if not list(m.children()) and list(m.parameters())]
   gradvacs = [GradVac() for _ in leaf_layers]

   for x, y1, y2 in zip(inputs, t1, t2):
       features = encoder(x)
       mtl_backward([loss_fn(task1_head(features), y1), loss_fn(task2_head(features), y2)], features=features)
       for layer, gradvac in zip(leaf_layers, gradvacs):
           jac_to_grad(layer.parameters(), gradvac)
       optimizer.step()
       optimizer.zero_grad()

4. All Matrices
---------------

One :class:`~torchjd.aggregation.GradVac` instance per individual parameter tensor. Cosine
similarities are computed between the per-tensor blocks of the task gradients (e.g. weights and
biases of each layer are treated as separate groups).

.. testcode::
   :emphasize-lines: 14-15, 20-21

   import torch
   from torch.nn import Linear, MSELoss, ReLU, Sequential
   from torch.optim import SGD

   from torchjd.aggregation import GradVac
   from torchjd.autojac import jac_to_grad, mtl_backward

   encoder = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
   task1_head, task2_head = Linear(3, 1), Linear(3, 1)
   optimizer = SGD([*encoder.parameters(), *task1_head.parameters(), *task2_head.parameters()], lr=0.1)
   loss_fn = MSELoss()
   inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1)

   shared_params = list(encoder.parameters())
   gradvacs = [GradVac() for _ in shared_params]

   for x, y1, y2 in zip(inputs, t1, t2):
       features = encoder(x)
       mtl_backward([loss_fn(task1_head(features), y1), loss_fn(task2_head(features), y2)], features=features)
       for param, gradvac in zip(shared_params, gradvacs):
           jac_to_grad([param], gradvac)
       optimizer.step()
       optimizer.zero_grad()
4 changes: 4 additions & 0 deletions docs/source/examples/index.rst
@@ -29,6 +29,9 @@ This section contains some usage examples for TorchJD.
- :doc:`PyTorch Lightning Integration <lightning_integration>` showcases how to combine
TorchJD with PyTorch Lightning, by providing an example implementation of a multi-task
``LightningModule`` optimized by Jacobian descent.
- :doc:`Grouping <grouping>` shows how to apply an aggregator independently per parameter group
(e.g. per layer), so that conflict resolution happens at a finer granularity than the full
shared parameter vector.
- :doc:`Automatic Mixed Precision <amp>` shows how to combine mixed precision training with TorchJD.

.. toctree::
@@ -43,3 +46,4 @@ This section contains some usage examples for TorchJD.
   monitoring.rst
   lightning_integration.rst
   amp.rst
   grouping.rst
3 changes: 3 additions & 0 deletions src/torchjd/aggregation/__init__.py
@@ -66,6 +66,7 @@
from ._dualproj import DualProj, DualProjWeighting
from ._flattening import Flattening
from ._graddrop import GradDrop
from ._gradvac import GradVac, GradVacWeighting
from ._imtl_g import IMTLG, IMTLGWeighting
from ._krum import Krum, KrumWeighting
from ._mean import Mean, MeanWeighting
@@ -92,6 +93,8 @@
"Flattening",
"GeneralizedWeighting",
"GradDrop",
"GradVac",
"GradVacWeighting",
"IMTLG",
"IMTLGWeighting",
"Krum",